zeta2 cli: Cache at LLM request level (#42371)

Created by Agus Zubiaga and Oleksiy Syvokon

We'll now cache LLM responses at the request level (keyed by a hash of the
URL plus the request contents) for both context and prediction. This way we
don't need to worry about accidentally reusing stale cache entries when we
change the prompt or any of its components.
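
A minimal sketch of the keying idea (the function name and the `&str` URL parameter are illustrative; the actual trait and wiring are in the diff below):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Illustrative sketch only (not the exact implementation): the cache key is a
// hash of the request URL plus the serialized request body, so any change to
// the prompt or its components yields a different key and therefore a miss.
fn llm_cache_key(url: &str, request_body_json: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    url.hash(&mut hasher);
    request_body_json.hash(&mut hasher);
    hasher.finish()
}
```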

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>

Change summary

crates/zeta2/Cargo.toml         |  3 +
crates/zeta2/src/zeta2.rs       | 73 +++++++++++++++++++++++++++++++---
crates/zeta_cli/Cargo.toml      |  2 +-
crates/zeta_cli/src/evaluate.rs | 46 ++++++++-------------
crates/zeta_cli/src/main.rs     |  4 ++--
crates/zeta_cli/src/paths.rs    |  2 +-
crates/zeta_cli/src/predict.rs  | 58 ++++++++++++++++++++++++++
7 files changed, 147 insertions(+), 41 deletions(-)

Detailed changes

crates/zeta2/Cargo.toml

@@ -11,6 +11,9 @@ workspace = true
 [lib]
 path = "src/zeta2.rs"
 
+[features]
+llm-response-cache = []
+
 [dependencies]
 anyhow.workspace = true
 arrayvec.workspace = true

crates/zeta2/src/zeta2.rs

@@ -131,6 +131,15 @@ pub struct Zeta {
     options: ZetaOptions,
     update_required: bool,
     debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
+    #[cfg(feature = "llm-response-cache")]
+    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
+}
+
+#[cfg(feature = "llm-response-cache")]
+pub trait LlmResponseCache: Send + Sync {
+    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
+    fn read_response(&self, key: u64) -> Option<String>;
+    fn write_response(&self, key: u64, value: &str);
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -359,9 +368,16 @@ impl Zeta {
             ),
             update_required: false,
             debug_tx: None,
+            #[cfg(feature = "llm-response-cache")]
+            llm_response_cache: None,
         }
     }
 
+    #[cfg(feature = "llm-response-cache")]
+    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
+        self.llm_response_cache = Some(cache);
+    }
+
     pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
         let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
         self.debug_tx = Some(debug_watch_tx);
@@ -734,6 +750,9 @@ impl Zeta {
             })
             .collect::<Vec<_>>();
 
+        #[cfg(feature = "llm-response-cache")]
+        let llm_response_cache = self.llm_response_cache.clone();
+
         let request_task = cx.background_spawn({
             let active_buffer = active_buffer.clone();
             async move {
@@ -923,8 +942,14 @@ impl Zeta {
                 log::trace!("Sending edit prediction request");
 
                 let before_request = chrono::Utc::now();
-                let response =
-                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+                let response = Self::send_raw_llm_request(
+                    request,
+                    client,
+                    llm_token,
+                    app_version,
+                    #[cfg(feature = "llm-response-cache")]
+                    llm_response_cache
+                ).await;
                 let request_time = chrono::Utc::now() - before_request;
 
                 log::trace!("Got edit prediction response");
@@ -1005,10 +1030,13 @@ impl Zeta {
     }
 
     async fn send_raw_llm_request(
+        request: open_ai::Request,
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: SemanticVersion,
-        request: open_ai::Request,
+        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
+            Arc<dyn LlmResponseCache>,
+        >,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
         let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
             http_client::Url::parse(&predict_edits_url)?
@@ -1018,7 +1046,21 @@ impl Zeta {
                 .build_zed_llm_url("/predict_edits/raw", &[])?
         };
 
-        Self::send_api_request(
+        #[cfg(feature = "llm-response-cache")]
+        let cache_key = if let Some(cache) = llm_response_cache {
+            let request_json = serde_json::to_string(&request)?;
+            let key = cache.get_key(&url, &request_json);
+
+            if let Some(response_str) = cache.read_response(key) {
+                return Ok((serde_json::from_str(&response_str)?, None));
+            }
+
+            Some((cache, key))
+        } else {
+            None
+        };
+
+        let (response, usage) = Self::send_api_request(
             |builder| {
                 let req = builder
                     .uri(url.as_ref())
@@ -1029,7 +1071,14 @@ impl Zeta {
             llm_token,
             app_version,
         )
-        .await
+        .await?;
+
+        #[cfg(feature = "llm-response-cache")]
+        if let Some((cache, key)) = cache_key {
+            cache.write_response(key, &serde_json::to_string(&response)?);
+        }
+
+        Ok((response, usage))
     }
 
     fn handle_api_response<T>(
@@ -1297,10 +1346,20 @@ impl Zeta {
             reasoning_effort: None,
         };
 
+        #[cfg(feature = "llm-response-cache")]
+        let llm_response_cache = self.llm_response_cache.clone();
+
         cx.spawn(async move |this, cx| {
             log::trace!("Sending search planning request");
-            let response =
-                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+            let response = Self::send_raw_llm_request(
+                request,
+                client,
+                llm_token,
+                app_version,
+                #[cfg(feature = "llm-response-cache")]
+                llm_response_cache,
+            )
+            .await;
             let mut response = Self::handle_api_response(&this, response, cx)?;
             log::trace!("Got search planning response");
 

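For reference, a minimal sketch of what a consumer of the new hook could look like, assuming the `llm-response-cache` feature is enabled and using the trait exactly as declared above. The in-memory cache here is hypothetical (e.g. for tests); the real zeta_cli implementation further below persists responses to disk instead.

```rust
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Mutex;

use gpui::http_client::Url;
use zeta2::LlmResponseCache;

// Hypothetical in-memory cache; zeta_cli's version writes each response
// as a JSON file under target/ instead.
struct InMemoryLlmCache {
    entries: Mutex<HashMap<u64, String>>,
}

impl LlmResponseCache for InMemoryLlmCache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        // Key by URL + serialized request body, matching the CLI's approach.
        let mut hasher = DefaultHasher::new();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        self.entries.lock().unwrap().get(&key).cloned()
    }

    fn write_response(&self, key: u64, value: &str) {
        self.entries.lock().unwrap().insert(key, value.to_string());
    }
}
```

It would be attached the same way the CLI does it below, e.g. `zeta.with_llm_response_cache(Arc::new(InMemoryLlmCache { entries: Default::default() }))`.
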
crates/zeta_cli/Cargo.toml

@@ -54,7 +54,7 @@ toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
-zeta2.workspace = true
+zeta2 = { workspace = true, features = ["llm-response-cache"] }
 zlog.workspace = true
 
 [dev-dependencies]

crates/zeta_cli/src/evaluate.rs

@@ -1,5 +1,4 @@
 use std::{
-    fs,
     io::IsTerminal,
     path::{Path, PathBuf},
     sync::Arc,
@@ -12,9 +11,9 @@ use gpui::AsyncApp;
 use zeta2::udiff::DiffLine;
 
 use crate::{
+    PromptFormat,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
-    paths::CACHE_DIR,
     predict::{PredictionDetails, zeta2_predict},
 };
 
@@ -22,7 +21,9 @@ use crate::{
 pub struct EvaluateArguments {
     example_paths: Vec<PathBuf>,
     #[clap(long)]
-    re_run: bool,
+    skip_cache: bool,
+    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+    prompt_format: PromptFormat,
 }
 
 pub async fn run_evaluate(
@@ -33,7 +34,16 @@ pub async fn run_evaluate(
     let example_len = args.example_paths.len();
     let all_tasks = args.example_paths.into_iter().map(|path| {
         let app_state = app_state.clone();
-        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
+        cx.spawn(async move |cx| {
+            run_evaluate_one(
+                &path,
+                args.skip_cache,
+                args.prompt_format,
+                app_state.clone(),
+                cx,
+            )
+            .await
+        })
     });
     let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
 
@@ -51,35 +61,15 @@ pub async fn run_evaluate(
 
 pub async fn run_evaluate_one(
     example_path: &Path,
-    re_run: bool,
+    skip_cache: bool,
+    prompt_format: PromptFormat,
     app_state: Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<EvaluationResult> {
     let example = NamedExample::load(&example_path).unwrap();
-    let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
-
-    let predictions = if !re_run && example_cache_path.exists() {
-        let file_contents = fs::read_to_string(&example_cache_path)?;
-        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
-        log::debug!(
-            "Loaded predictions from cache: {}",
-            example_cache_path.display()
-        );
-        as_json
-    } else {
-        zeta2_predict(example.clone(), Default::default(), &app_state, cx)
-            .await
-            .unwrap()
-    };
-
-    if !example_cache_path.exists() {
-        fs::create_dir_all(&*CACHE_DIR).unwrap();
-        fs::write(
-            example_cache_path,
-            serde_json::to_string(&predictions).unwrap(),
-        )
+    let predictions = zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx)
+        .await
         .unwrap();
-    }
 
     let evaluation_result = evaluate(&example.example, &predictions);
 

crates/zeta_cli/src/main.rs

@@ -158,13 +158,13 @@ fn syntax_args_to_options(
         }),
         max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
         max_prompt_bytes: zeta2_args.max_prompt_bytes,
-        prompt_format: zeta2_args.prompt_format.clone().into(),
+        prompt_format: zeta2_args.prompt_format.into(),
         file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
         buffer_change_grouping_interval: Duration::ZERO,
     }
 }
 
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 enum PromptFormat {
     MarkedExcerpt,
     LabeledSections,

crates/zeta_cli/src/paths.rs

@@ -2,7 +2,7 @@ use std::{env, path::PathBuf, sync::LazyLock};
 
 static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
 pub static CACHE_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
+    LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
 pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
 pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
 pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));

crates/zeta_cli/src/predict.rs

@@ -1,10 +1,11 @@
 use crate::PromptFormat;
 use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
-use crate::paths::LOGS_DIR;
+use crate::paths::{CACHE_DIR, LOGS_DIR};
 use ::serde::Serialize;
 use anyhow::{Result, anyhow};
 use clap::Args;
+use gpui::http_client::Url;
 // use cloud_llm_client::predict_edits_v3::PromptFormat;
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use futures::StreamExt as _;
@@ -18,6 +19,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
+use zeta2::LlmResponseCache;
 
 #[derive(Debug, Args)]
 pub struct PredictArguments {
@@ -26,6 +28,8 @@ pub struct PredictArguments {
     #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
     format: PredictionsOutputFormat,
     example_path: PathBuf,
+    #[clap(long)]
+    skip_cache: bool,
 }
 
 #[derive(clap::ValueEnum, Debug, Clone)]
@@ -40,7 +44,7 @@ pub async fn run_zeta2_predict(
     cx: &mut AsyncApp,
 ) {
     let example = NamedExample::load(args.example_path).unwrap();
-    let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
+    let result = zeta2_predict(example, args.skip_cache, args.prompt_format, &app_state, cx)
         .await
         .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
@@ -52,6 +56,7 @@ thread_local! {
 
 pub async fn zeta2_predict(
     example: NamedExample,
+    skip_cache: bool,
     prompt_format: PromptFormat,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
@@ -95,6 +100,10 @@ pub async fn zeta2_predict(
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
+    zeta.update(cx, |zeta, _cx| {
+        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+    })?;
+
     cx.subscribe(&buffer_store, {
         let project = project.clone();
         move |_, event, cx| match event {
@@ -233,6 +242,51 @@ pub async fn zeta2_predict(
     anyhow::Ok(result)
 }
 
+struct Cache {
+    skip_cache: bool,
+}
+
+impl Cache {
+    fn path(key: u64) -> PathBuf {
+        CACHE_DIR.join(format!("{key:x}.json"))
+    }
+}
+
+impl LlmResponseCache for Cache {
+    fn get_key(&self, url: &Url, body: &str) -> u64 {
+        use collections::FxHasher;
+        use std::hash::{Hash, Hasher};
+
+        let mut hasher = FxHasher::default();
+        url.hash(&mut hasher);
+        body.hash(&mut hasher);
+        hasher.finish()
+    }
+
+    fn read_response(&self, key: u64) -> Option<String> {
+        let path = Cache::path(key);
+        if path.exists() {
+            if self.skip_cache {
+                log::info!("Skipping existing cached LLM response: {}", path.display());
+                None
+            } else {
+                log::info!("Using LLM response from cache: {}", path.display());
+                Some(fs::read_to_string(path).unwrap())
+            }
+        } else {
+            None
+        }
+    }
+
+    fn write_response(&self, key: u64, value: &str) {
+        fs::create_dir_all(&*CACHE_DIR).unwrap();
+
+        let path = Cache::path(key);
+        log::info!("Writing LLM response to cache: {}", path.display());
+        fs::write(path, value).unwrap();
+    }
+}
+
 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
 pub struct PredictionDetails {
     pub diff: String,