Agus Zubiaga and Oleksiy Syvokon created this change:
We'll now cache LLM responses at the request level (by hash of
URL+contents) for both context and prediction. This way we don't need to
worry about mistakenly using the cache when we change the prompt or its
components.
Release Notes:
- N/A
---------
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
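For illustration, here is a minimal sketch of how a consumer might implement and register the new cache. The `InMemoryCache` type is hypothetical; the `LlmResponseCache` trait, the `with_llm_response_cache` setter, and the `llm-response-cache` feature gate are the ones added in the diff below.

```rust
// Sketch only: a hypothetical in-memory implementation of the new
// `LlmResponseCache` trait, keyed by a hash of the request URL + body.
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Mutex;

use zeta2::LlmResponseCache;

#[derive(Default)]
struct InMemoryCache(Mutex<HashMap<u64, String>>);

impl LlmResponseCache for InMemoryCache {
    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64 {
        // Hash the URL together with the full request body, so any change to
        // the prompt or its components produces a new cache key.
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        self.0.lock().unwrap().get(&key).cloned()
    }

    fn write_response(&self, key: u64, value: &str) {
        self.0.lock().unwrap().insert(key, value.to_string());
    }
}

// Registering it on a `Zeta` instance (requires the `llm-response-cache` feature):
// zeta.with_llm_response_cache(std::sync::Arc::new(InMemoryCache::default()));
```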
 crates/zeta2/Cargo.toml         |  3
 crates/zeta2/src/zeta2.rs       | 73 +++++++++++++++++++++++++++++++---
 crates/zeta_cli/Cargo.toml      |  2
 crates/zeta_cli/src/evaluate.rs | 46 ++++++++-------------
 crates/zeta_cli/src/main.rs     |  4
 crates/zeta_cli/src/paths.rs    |  2
 crates/zeta_cli/src/predict.rs  | 58 ++++++++++++++++++++++++++
 7 files changed, 147 insertions(+), 41 deletions(-)
crates/zeta2/Cargo.toml
@@ -11,6 +11,9 @@ workspace = true
[lib]
path = "src/zeta2.rs"
+[features]
+llm-response-cache = []
+
[dependencies]
anyhow.workspace = true
arrayvec.workspace = true
crates/zeta2/src/zeta2.rs
@@ -131,6 +131,15 @@ pub struct Zeta {
options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
+ #[cfg(feature = "llm-response-cache")]
+ llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
+}
+
+#[cfg(feature = "llm-response-cache")]
+pub trait LlmResponseCache: Send + Sync {
+ fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
+ fn read_response(&self, key: u64) -> Option<String>;
+ fn write_response(&self, key: u64, value: &str);
}
#[derive(Debug, Clone, PartialEq)]
@@ -359,9 +368,16 @@ impl Zeta {
),
update_required: false,
debug_tx: None,
+ #[cfg(feature = "llm-response-cache")]
+ llm_response_cache: None,
}
}
+ #[cfg(feature = "llm-response-cache")]
+ pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
+ self.llm_response_cache = Some(cache);
+ }
+
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
self.debug_tx = Some(debug_watch_tx);
@@ -734,6 +750,9 @@ impl Zeta {
})
.collect::<Vec<_>>();
+ #[cfg(feature = "llm-response-cache")]
+ let llm_response_cache = self.llm_response_cache.clone();
+
let request_task = cx.background_spawn({
let active_buffer = active_buffer.clone();
async move {
@@ -923,8 +942,14 @@ impl Zeta {
log::trace!("Sending edit prediction request");
let before_request = chrono::Utc::now();
- let response =
- Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+ let response = Self::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "llm-response-cache")]
+ llm_response_cache
+ ).await;
let request_time = chrono::Utc::now() - before_request;
log::trace!("Got edit prediction response");
@@ -1005,10 +1030,13 @@ impl Zeta {
}
async fn send_raw_llm_request(
+ request: open_ai::Request,
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: SemanticVersion,
- request: open_ai::Request,
+ #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
+ Arc<dyn LlmResponseCache>,
+ >,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
http_client::Url::parse(&predict_edits_url)?
@@ -1018,7 +1046,21 @@ impl Zeta {
.build_zed_llm_url("/predict_edits/raw", &[])?
};
- Self::send_api_request(
+ #[cfg(feature = "llm-response-cache")]
+ let cache_key = if let Some(cache) = llm_response_cache {
+ let request_json = serde_json::to_string(&request)?;
+ let key = cache.get_key(&url, &request_json);
+
+ if let Some(response_str) = cache.read_response(key) {
+ return Ok((serde_json::from_str(&response_str)?, None));
+ }
+
+ Some((cache, key))
+ } else {
+ None
+ };
+
+ let (response, usage) = Self::send_api_request(
|builder| {
let req = builder
.uri(url.as_ref())
@@ -1029,7 +1071,14 @@ impl Zeta {
llm_token,
app_version,
)
- .await
+ .await?;
+
+ #[cfg(feature = "llm-response-cache")]
+ if let Some((cache, key)) = cache_key {
+ cache.write_response(key, &serde_json::to_string(&response)?);
+ }
+
+ Ok((response, usage))
}
fn handle_api_response<T>(
@@ -1297,10 +1346,20 @@ impl Zeta {
reasoning_effort: None,
};
+ #[cfg(feature = "llm-response-cache")]
+ let llm_response_cache = self.llm_response_cache.clone();
+
cx.spawn(async move |this, cx| {
log::trace!("Sending search planning request");
- let response =
- Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+ let response = Self::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "llm-response-cache")]
+ llm_response_cache,
+ )
+ .await;
let mut response = Self::handle_api_response(&this, response, cx)?;
log::trace!("Got search planning response");
crates/zeta_cli/Cargo.toml
@@ -54,7 +54,7 @@ toml.workspace = true
util.workspace = true
watch.workspace = true
zeta.workspace = true
-zeta2.workspace = true
+zeta2 = { workspace = true, features = ["llm-response-cache"] }
zlog.workspace = true
[dev-dependencies]
crates/zeta_cli/src/evaluate.rs
@@ -1,5 +1,4 @@
use std::{
- fs,
io::IsTerminal,
path::{Path, PathBuf},
sync::Arc,
@@ -12,9 +11,9 @@ use gpui::AsyncApp;
use zeta2::udiff::DiffLine;
use crate::{
+ PromptFormat,
example::{Example, NamedExample},
headless::ZetaCliAppState,
- paths::CACHE_DIR,
predict::{PredictionDetails, zeta2_predict},
};
@@ -22,7 +21,9 @@ use crate::{
pub struct EvaluateArguments {
example_paths: Vec<PathBuf>,
#[clap(long)]
- re_run: bool,
+ skip_cache: bool,
+ #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+ prompt_format: PromptFormat,
}
pub async fn run_evaluate(
@@ -33,7 +34,16 @@ pub async fn run_evaluate(
let example_len = args.example_paths.len();
let all_tasks = args.example_paths.into_iter().map(|path| {
let app_state = app_state.clone();
- cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
+ cx.spawn(async move |cx| {
+ run_evaluate_one(
+ &path,
+ args.skip_cache,
+ args.prompt_format,
+ app_state.clone(),
+ cx,
+ )
+ .await
+ })
});
let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
@@ -51,35 +61,15 @@ pub async fn run_evaluate(
pub async fn run_evaluate_one(
example_path: &Path,
- re_run: bool,
+ skip_cache: bool,
+ prompt_format: PromptFormat,
app_state: Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
let example = NamedExample::load(&example_path).unwrap();
- let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
-
- let predictions = if !re_run && example_cache_path.exists() {
- let file_contents = fs::read_to_string(&example_cache_path)?;
- let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
- log::debug!(
- "Loaded predictions from cache: {}",
- example_cache_path.display()
- );
- as_json
- } else {
- zeta2_predict(example.clone(), Default::default(), &app_state, cx)
- .await
- .unwrap()
- };
-
- if !example_cache_path.exists() {
- fs::create_dir_all(&*CACHE_DIR).unwrap();
- fs::write(
- example_cache_path,
- serde_json::to_string(&predictions).unwrap(),
- )
+ let predictions = zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx)
+ .await
.unwrap();
- }
let evaluation_result = evaluate(&example.example, &predictions);
crates/zeta_cli/src/main.rs
@@ -158,13 +158,13 @@ fn syntax_args_to_options(
}),
max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
max_prompt_bytes: zeta2_args.max_prompt_bytes,
- prompt_format: zeta2_args.prompt_format.clone().into(),
+ prompt_format: zeta2_args.prompt_format.into(),
file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
buffer_change_grouping_interval: Duration::ZERO,
}
}
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
MarkedExcerpt,
LabeledSections,
crates/zeta_cli/src/paths.rs
@@ -2,7 +2,7 @@ use std::{env, path::PathBuf, sync::LazyLock};
static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
pub static CACHE_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
+ LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
crates/zeta_cli/src/predict.rs
@@ -1,10 +1,11 @@
use crate::PromptFormat;
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
-use crate::paths::LOGS_DIR;
+use crate::paths::{CACHE_DIR, LOGS_DIR};
use ::serde::Serialize;
use anyhow::{Result, anyhow};
use clap::Args;
+use gpui::http_client::Url;
// use cloud_llm_client::predict_edits_v3::PromptFormat;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
@@ -18,6 +19,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
+use zeta2::LlmResponseCache;
#[derive(Debug, Args)]
pub struct PredictArguments {
@@ -26,6 +28,8 @@ pub struct PredictArguments {
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
+ #[clap(long)]
+ skip_cache: bool,
}
#[derive(clap::ValueEnum, Debug, Clone)]
@@ -40,7 +44,7 @@ pub async fn run_zeta2_predict(
cx: &mut AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
- let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
+ let result = zeta2_predict(example, args.skip_cache, args.prompt_format, &app_state, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
@@ -52,6 +56,7 @@ thread_local! {
pub async fn zeta2_predict(
example: NamedExample,
+ skip_cache: bool,
prompt_format: PromptFormat,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@@ -95,6 +100,10 @@ pub async fn zeta2_predict(
let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+ zeta.update(cx, |zeta, _cx| {
+ zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+ })?;
+
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
@@ -233,6 +242,51 @@ pub async fn zeta2_predict(
anyhow::Ok(result)
}
+struct Cache {
+ skip_cache: bool,
+}
+
+impl Cache {
+ fn path(key: u64) -> PathBuf {
+ CACHE_DIR.join(format!("{key:x}.json"))
+ }
+}
+
+impl LlmResponseCache for Cache {
+ fn get_key(&self, url: &Url, body: &str) -> u64 {
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
+
+ let mut hasher = FxHasher::default();
+ url.hash(&mut hasher);
+ body.hash(&mut hasher);
+ hasher.finish()
+ }
+
+ fn read_response(&self, key: u64) -> Option<String> {
+ let path = Cache::path(key);
+ if path.exists() {
+ if self.skip_cache {
+ log::info!("Skipping existing cached LLM response: {}", path.display());
+ None
+ } else {
+ log::info!("Using LLM response from cache: {}", path.display());
+ Some(fs::read_to_string(path).unwrap())
+ }
+ } else {
+ None
+ }
+ }
+
+ fn write_response(&self, key: u64, value: &str) {
+ fs::create_dir_all(&*CACHE_DIR).unwrap();
+
+ let path = Cache::path(key);
+ log::info!("Writing LLM response to cache: {}", path.display());
+ fs::write(path, value).unwrap();
+ }
+}
+
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PredictionDetails {
pub diff: String,