diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml
index 3f394cd5ef2ab5d5bce05430a717312c9e3c0f5c..1cb3a866065748f8e39dee7a980b99ea0b6c63fa 100644
--- a/crates/zeta2/Cargo.toml
+++ b/crates/zeta2/Cargo.toml
@@ -11,6 +11,9 @@ workspace = true
 [lib]
 path = "src/zeta2.rs"
 
+[features]
+llm-response-cache = []
+
 [dependencies]
 anyhow.workspace = true
 arrayvec.workspace = true
diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs
index 297bfa1c4a940448e7fdb570ea4b808556c3f416..c77c78b6f517bce085a26b2c60d04318b2f3cdae 100644
--- a/crates/zeta2/src/zeta2.rs
+++ b/crates/zeta2/src/zeta2.rs
@@ -131,6 +131,15 @@ pub struct Zeta {
     options: ZetaOptions,
     update_required: bool,
     debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
+    #[cfg(feature = "llm-response-cache")]
+    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
+}
+
+#[cfg(feature = "llm-response-cache")]
+pub trait LlmResponseCache: Send + Sync {
+    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
+    fn read_response(&self, key: u64) -> Option<String>;
+    fn write_response(&self, key: u64, value: &str);
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -359,9 +368,16 @@ impl Zeta {
             ),
             update_required: false,
             debug_tx: None,
+            #[cfg(feature = "llm-response-cache")]
+            llm_response_cache: None,
         }
     }
 
+    #[cfg(feature = "llm-response-cache")]
+    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
+        self.llm_response_cache = Some(cache);
+    }
+
     pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
         let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
         self.debug_tx = Some(debug_watch_tx);
@@ -734,6 +750,9 @@ impl Zeta {
             })
             .collect::<Vec<_>>();
 
+        #[cfg(feature = "llm-response-cache")]
+        let llm_response_cache = self.llm_response_cache.clone();
+
         let request_task = cx.background_spawn({
             let active_buffer = active_buffer.clone();
             async move {
@@ -923,8 +942,14 @@ impl Zeta {
 
                 log::trace!("Sending edit prediction request");
                 let before_request = chrono::Utc::now();
-                let response =
-                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+                let response = Self::send_raw_llm_request(
+                    request,
+                    client,
+                    llm_token,
+                    app_version,
+                    #[cfg(feature = "llm-response-cache")]
+                    llm_response_cache
+                ).await;
                 let request_time = chrono::Utc::now() - before_request;
                 log::trace!("Got edit prediction response");
 
@@ -1005,10 +1030,13 @@ impl Zeta {
     }
 
     async fn send_raw_llm_request(
+        request: open_ai::Request,
        client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: SemanticVersion,
-        request: open_ai::Request,
+        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
+            Arc<dyn LlmResponseCache>,
+        >,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
         let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
             http_client::Url::parse(&predict_edits_url)?
@@ -1015,10 +1043,24 @@ impl Zeta {
         } else {
             client
                 .http_client()
                 .build_zed_llm_url("/predict_edits/raw", &[])?
         };
 
-        Self::send_api_request(
+        #[cfg(feature = "llm-response-cache")]
+        let cache_key = if let Some(cache) = llm_response_cache {
+            let request_json = serde_json::to_string(&request)?;
+            let key = cache.get_key(&url, &request_json);
+
+            if let Some(response_str) = cache.read_response(key) {
+                return Ok((serde_json::from_str(&response_str)?, None));
+            }
+
+            Some((cache, key))
+        } else {
+            None
+        };
+
+        let (response, usage) = Self::send_api_request(
             |builder| {
                 let req = builder
                     .uri(url.as_ref())
@@ -1029,7 +1071,14 @@ impl Zeta {
             llm_token,
             app_version,
         )
-        .await
+        .await?;
+
+        #[cfg(feature = "llm-response-cache")]
+        if let Some((cache, key)) = cache_key {
+            cache.write_response(key, &serde_json::to_string(&response)?);
+        }
+
+        Ok((response, usage))
     }
 
     fn handle_api_response(
@@ -1297,10 +1346,20 @@ impl Zeta {
             reasoning_effort: None,
         };
 
+        #[cfg(feature = "llm-response-cache")]
+        let llm_response_cache = self.llm_response_cache.clone();
+
        cx.spawn(async move |this, cx| {
             log::trace!("Sending search planning request");
-            let response =
-                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+            let response = Self::send_raw_llm_request(
+                request,
+                client,
+                llm_token,
+                app_version,
+                #[cfg(feature = "llm-response-cache")]
+                llm_response_cache,
+            )
+            .await;
 
             let mut response = Self::handle_api_response(&this, response, cx)?;
             log::trace!("Got search planning response");
diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml
index 5bf90910f18f085db42d5f7934d13601e1c691a2..2e62f2a4462e31b7632aa5e825ea76a4b7df5fc8 100644
--- a/crates/zeta_cli/Cargo.toml
+++ b/crates/zeta_cli/Cargo.toml
@@ -54,7 +54,7 @@ toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
-zeta2.workspace = true
+zeta2 = { workspace = true, features = ["llm-response-cache"] }
 zlog.workspace = true
 
 [dev-dependencies]
diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs
index c0f513fa38df5fb837be2294845eeae3214074bd..6d5b2da13a4301bfb52cb3cda7662843dea7cd12 100644
--- a/crates/zeta_cli/src/evaluate.rs
+++ b/crates/zeta_cli/src/evaluate.rs
@@ -1,5 +1,4 @@
 use std::{
-    fs,
     io::IsTerminal,
     path::{Path, PathBuf},
     sync::Arc,
@@ -12,9 +11,9 @@ use gpui::AsyncApp;
 use zeta2::udiff::DiffLine;
 
 use crate::{
+    PromptFormat,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
-    paths::CACHE_DIR,
     predict::{PredictionDetails, zeta2_predict},
 };
 
@@ -22,7 +21,9 @@ use crate::{
 pub struct EvaluateArguments {
     example_paths: Vec<PathBuf>,
     #[clap(long)]
-    re_run: bool,
+    skip_cache: bool,
+    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+    prompt_format: PromptFormat,
 }
 
 pub async fn run_evaluate(
@@ -33,7 +34,16 @@ pub async fn run_evaluate(
     let example_len = args.example_paths.len();
     let all_tasks = args.example_paths.into_iter().map(|path| {
         let app_state = app_state.clone();
-        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
+        cx.spawn(async move |cx| {
+            run_evaluate_one(
+                &path,
+                args.skip_cache,
+                args.prompt_format,
+                app_state.clone(),
+                cx,
+            )
+            .await
+        })
     });
 
     let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
@@ -51,35 +61,15 @@ pub async fn run_evaluate(
 
 pub async fn run_evaluate_one(
     example_path: &Path,
-    re_run: bool,
+    skip_cache: bool,
+    prompt_format: PromptFormat,
     app_state: Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<EvaluationResult> {
     let example = NamedExample::load(&example_path).unwrap();
 
-    let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
-
-    let predictions = if !re_run && example_cache_path.exists() {
-        let file_contents = fs::read_to_string(&example_cache_path)?;
-        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
-        log::debug!(
-            "Loaded predictions from cache: {}",
-            example_cache_path.display()
-        );
-        as_json
-    } else {
-        zeta2_predict(example.clone(), Default::default(), &app_state, cx)
-            .await
-            .unwrap()
-    };
-
-    if !example_cache_path.exists() {
-        fs::create_dir_all(&*CACHE_DIR).unwrap();
-        fs::write(
-            example_cache_path,
-            serde_json::to_string(&predictions).unwrap(),
-        )
+    let predictions = zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx)
+        .await
         .unwrap();
-    }
 
     let evaluation_result = evaluate(&example.example, &predictions);
 
diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs
index 66b4a6c8bd71ce046b6336ecb671d491128af945..25fb920bab18f374e41b539bc21320faf6c75484 100644
--- a/crates/zeta_cli/src/main.rs
+++ b/crates/zeta_cli/src/main.rs
@@ -158,13 +158,13 @@ fn syntax_args_to_options(
         }),
         max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
         max_prompt_bytes: zeta2_args.max_prompt_bytes,
-        prompt_format: zeta2_args.prompt_format.clone().into(),
+        prompt_format: zeta2_args.prompt_format.into(),
         file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
         buffer_change_grouping_interval: Duration::ZERO,
     }
 }
 
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
+#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 enum PromptFormat {
     MarkedExcerpt,
     LabeledSections,
diff --git a/crates/zeta_cli/src/paths.rs b/crates/zeta_cli/src/paths.rs
index 61987607bf2a5bb99eae68db4863f97bb282b29c..144bf6f5dd97c518d965d7bd23da83ce7f11f66f 100644
--- a/crates/zeta_cli/src/paths.rs
+++ b/crates/zeta_cli/src/paths.rs
@@ -2,7 +2,7 @@ use std::{env, path::PathBuf, sync::LazyLock};
 static TARGET_DIR: LazyLock<PathBuf> =
     LazyLock::new(|| env::current_dir().unwrap().join("target"));
 pub static CACHE_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
+    LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
 pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
 pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
 pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs
index f7f503ffebe24d71023ad259ce76adfdea364efc..d85f009c9bacc0b6177683c064979740a0709115 100644
--- a/crates/zeta_cli/src/predict.rs
+++ b/crates/zeta_cli/src/predict.rs
@@ -1,10 +1,11 @@
 use crate::PromptFormat;
 use crate::example::{ActualExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
-use crate::paths::LOGS_DIR;
+use crate::paths::{CACHE_DIR, LOGS_DIR};
 use ::serde::Serialize;
 use anyhow::{Result, anyhow};
 use clap::Args;
+use gpui::http_client::Url;
 // use cloud_llm_client::predict_edits_v3::PromptFormat;
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
 use futures::StreamExt as _;
@@ -18,6 +19,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
+use zeta2::LlmResponseCache;
 
 #[derive(Debug, Args)]
 pub struct PredictArguments {
@@ -26,6 +28,8 @@ pub struct PredictArguments {
     #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
     format: PredictionsOutputFormat,
     example_path: PathBuf,
+    #[clap(long)]
+    skip_cache: bool,
 }
 
 #[derive(clap::ValueEnum, Debug, Clone)]
@@ -40,7 +44,7 @@ pub async fn run_zeta2_predict(
     cx: &mut AsyncApp,
 ) {
     let example = NamedExample::load(args.example_path).unwrap();
-    let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
+    let result = zeta2_predict(example, args.skip_cache, args.prompt_format, &app_state, cx)
         .await
         .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
@@ -52,6 +56,7 @@ thread_local! {
 
 pub async fn zeta2_predict(
     example: NamedExample,
+    skip_cache: bool,
     prompt_format: PromptFormat,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
@@ -95,6 +100,10 @@ pub async fn zeta2_predict(
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
+    zeta.update(cx, |zeta, _cx| {
+        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+    })?;
+
     cx.subscribe(&buffer_store, {
         let project = project.clone();
         move |_, event, cx| match event {
@@ -233,6 +242,51 @@ pub async fn zeta2_predict(
     anyhow::Ok(result)
 }
 
+struct Cache {
+    skip_cache: bool,
+}
+
+impl Cache {
+    fn path(key: u64) -> PathBuf {
+        CACHE_DIR.join(format!("{key:x}.json"))
+    }
+}
+
+impl LlmResponseCache for Cache {
+    fn get_key(&self, url: &Url, body: &str) -> u64 {
+        use collections::FxHasher;
+        use std::hash::{Hash, Hasher};
+
+        let mut hasher = FxHasher::default();
+        url.hash(&mut hasher);
+        body.hash(&mut hasher);
+        hasher.finish()
+    }
+
+    fn read_response(&self, key: u64) -> Option<String> {
+        let path = Cache::path(key);
+        if path.exists() {
+            if self.skip_cache {
+                log::info!("Skipping existing cached LLM response: {}", path.display());
+                None
+            } else {
+                log::info!("Using LLM response from cache: {}", path.display());
+                Some(fs::read_to_string(path).unwrap())
+            }
+        } else {
+            None
+        }
+    }
+
+    fn write_response(&self, key: u64, value: &str) {
+        fs::create_dir_all(&*CACHE_DIR).unwrap();
+
+        let path = Cache::path(key);
+        log::info!("Writing LLM response to cache: {}", path.display());
+        fs::write(path, value).unwrap();
+    }
+}
+
 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
 pub struct PredictionDetails {
     pub diff: String,
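
For reference, a minimal in-memory implementation of the new `LlmResponseCache` trait might look like the sketch below. This is illustrative only — the `MemoryCache` type is hypothetical and not part of this change. It assumes `zeta2` is built with the `llm-response-cache` feature and mirrors the `FxHasher` keying scheme used by the CLI's file-backed `Cache` above:

```rust
use std::{collections::HashMap, sync::Mutex};

use gpui::http_client::Url;
use zeta2::LlmResponseCache;

/// Hypothetical in-memory cache (e.g. for tests): entries live only for the
/// lifetime of the process instead of under target/zeta-llm-response-cache.
#[derive(Default)]
struct MemoryCache {
    entries: Mutex<HashMap<u64, String>>,
}

impl LlmResponseCache for MemoryCache {
    fn get_key(&self, url: &Url, body: &str) -> u64 {
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        // Same keying scheme as the file-backed `Cache` in predict.rs:
        // hash the endpoint URL together with the serialized request body.
        let mut hasher = FxHasher::default();
        url.hash(&mut hasher);
        body.hash(&mut hasher);
        hasher.finish()
    }

    fn read_response(&self, key: u64) -> Option<String> {
        self.entries.lock().unwrap().get(&key).cloned()
    }

    fn write_response(&self, key: u64, value: &str) {
        self.entries.lock().unwrap().insert(key, value.to_owned());
    }
}
```

As with the CLI's `Cache`, an instance would be installed via `zeta.with_llm_response_cache(Arc::new(MemoryCache::default()))` before the first request; `send_raw_llm_request` then returns the cached `open_ai::Response` (with `None` usage) on a key hit and only reaches the network on a miss.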