diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
index 7fbc3834dfd0f4bbfc4085d696b7fbf755e6dd3d..a11c56da41384257b8331a31161224c9e25d0894 100644
--- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
+++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
@@ -44,7 +44,7 @@ pub struct SearchToolInput {
 }
 
 /// Search for relevant code by path, syntax hierarchy, and content.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
 pub struct SearchToolQuery {
     /// 1. A glob pattern to match file paths in the codebase to search in.
     pub glob: String,
diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml
index 0360a74e65a109a0c95ea4787a0df1c61b375615..1eef507e6def3d80560ff1515623d0c42687d74a 100644
--- a/crates/zeta2/Cargo.toml
+++ b/crates/zeta2/Cargo.toml
@@ -12,7 +12,7 @@ workspace = true
path = "src/zeta2.rs"
 
 [features]
-llm-response-cache = []
+eval-support = []
 
 [dependencies]
 anyhow.workspace = true
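Note: the new `Hash` derive on `SearchToolQuery` is what lets a batch of search queries act as an eval-cache key. A minimal sketch of the keying scheme, mirroring the `FxHasher` usage introduced in `retrieval_search.rs` below (the helper name is ours, and it assumes `SearchToolQuery` and the workspace `collections` crate are in scope):

```rust
use std::hash::{Hash, Hasher};
use std::path::Path;

// Hypothetical helper: derive a stable cache key from the worktree and the
// query list, the same two ingredients hashed in run_retrieval_searches below.
fn search_cache_key(worktree_abs_path: &Path, queries: &[SearchToolQuery]) -> u64 {
    let mut hasher = collections::FxHasher::default();
    worktree_abs_path.hash(&mut hasher); // ties the key to the project under eval
    queries.hash(&mut hasher); // requires the `Hash` derive added above
    hasher.finish()
}
```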
diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs
index fe28976bb27e27cd6355d3efa13e0a1bf26d5962..76501fb1e5c73a22ff8eebc5c29d117d45389beb 100644
--- a/crates/zeta2/src/retrieval_search.rs
+++ b/crates/zeta2/src/retrieval_search.rs
@@ -1,5 +1,3 @@
-use std::ops::Range;
-
 use anyhow::Result;
 use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use collections::HashMap;
@@ -14,17 +12,76 @@ use project::{
     search::{SearchQuery, SearchResult},
 };
 use smol::channel;
+use std::ops::Range;
 use util::{
     ResultExt as _,
     paths::{PathMatcher, PathStyle},
 };
 use workspace::item::Settings as _;
 
+#[cfg(feature = "eval-support")]
+type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
+
 pub async fn run_retrieval_searches(
-    project: Entity<Project>,
     queries: Vec<SearchToolQuery>,
+    project: Entity<Project>,
+    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
     cx: &mut AsyncApp,
 ) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
+    #[cfg(feature = "eval-support")]
+    let cache = if let Some(eval_cache) = eval_cache {
+        use crate::EvalCacheEntryKind;
+        use anyhow::Context;
+        use collections::FxHasher;
+        use std::hash::{Hash, Hasher};
+
+        let mut hasher = FxHasher::default();
+        project.read_with(cx, |project, cx| {
+            let mut worktrees = project.worktrees(cx);
+            let Some(worktree) = worktrees.next() else {
+                panic!("Expected a single worktree in eval project. Found none.");
+            };
+            assert!(
+                worktrees.next().is_none(),
+                "Expected a single worktree in eval project. Found more than one."
+            );
+            worktree.read(cx).abs_path().hash(&mut hasher);
+        })?;
+
+        queries.hash(&mut hasher);
+        let key = (EvalCacheEntryKind::Search, hasher.finish());
+
+        if let Some(cached_results) = eval_cache.read(key) {
+            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
+                .context("Failed to deserialize cached search results")?;
+            let mut results = HashMap::default();
+
+            for (path, ranges) in file_results {
+                let buffer = project
+                    .update(cx, |project, cx| {
+                        let project_path = project.find_project_path(path, cx).unwrap();
+                        project.open_buffer(project_path, cx)
+                    })?
+                    .await?;
+                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+                let mut ranges = ranges
+                    .into_iter()
+                    .map(|range| {
+                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
+                    })
+                    .collect();
+                merge_anchor_ranges(&mut ranges, &snapshot);
+                results.insert(buffer, ranges);
+            }
+
+            return Ok(results);
+        }
+
+        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
+    } else {
+        None
+    };
+
     let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
         let global_settings = WorktreeSettings::get_global(cx);
         let exclude_patterns = global_settings
@@ -58,6 +115,8 @@ pub async fn run_retrieval_searches(
     }
     drop(results_tx);
 
+    #[cfg(feature = "eval-support")]
+    let cache = cache.clone();
     cx.background_spawn(async move {
         let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
         let mut snapshots = HashMap::default();
@@ -79,6 +138,29 @@ pub async fn run_retrieval_searches(
             }
         }
 
+        #[cfg(feature = "eval-support")]
+        if let Some((cache, queries, key)) = cache {
+            let cached_results: CachedSearchResults = results
+                .iter()
+                .filter_map(|(buffer, ranges)| {
+                    let snapshot = snapshots.get(&buffer.entity_id())?;
+                    let path = snapshot.file().map(|f| f.path());
+                    let mut ranges = ranges
+                        .iter()
+                        .map(|range| range.to_offset(&snapshot))
+                        .collect::<Vec<_>>();
+                    ranges.sort_unstable_by_key(|range| (range.start, range.end));
+
+                    Some((path?.as_std_path().to_path_buf(), ranges))
+                })
+                .collect();
+            cache.write(
+                key,
+                &queries,
+                &serde_json::to_string_pretty(&cached_results)?,
+            );
+        }
+
         for (buffer, ranges) in results.iter_mut() {
             if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
                 merge_anchor_ranges(ranges, snapshot);
@@ -489,9 +571,10 @@ mod tests {
         expected_output: &str,
         cx: &mut TestAppContext,
     ) {
-        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
-            .await
-            .unwrap();
+        let results =
+            run_retrieval_searches(vec![query], project.clone(), None, &mut cx.to_async())
+                .await
+                .unwrap();
 
         let mut results = results.into_iter().collect::<Vec<_>>();
         results.sort_by_key(|results| {
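For reference, `CachedSearchResults` round-trips through `serde_json` as a map from file path to sorted byte-offset ranges. A self-contained sketch of that shape (paths and offsets invented; assumes a `serde_json` dependency):

```rust
use std::{collections::BTreeMap, ops::Range, path::PathBuf};

// Mirrors the CachedSearchResults alias above.
type CachedSearchResults = BTreeMap<PathBuf, Vec<Range<usize>>>;

fn main() -> serde_json::Result<()> {
    // serde serializes Range<usize> as { "start": .., "end": .. }.
    let json = r#"{ "src/main.rs": [ { "start": 120, "end": 480 } ] }"#;
    let cached: CachedSearchResults = serde_json::from_str(json)?;
    assert_eq!(cached[&PathBuf::from("src/main.rs")], vec![120..480]);
    Ok(())
}
```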
diff --git a/crates/zeta2/src/xml_edits.rs b/crates/zeta2/src/xml_edits.rs
index 6c9b5a97f6398cc00eaca08f9af6c4c9de991785..97087ec65e06a1a2f418ca0c4ebba41a19b1af84 100644
--- a/crates/zeta2/src/xml_edits.rs
+++ b/crates/zeta2/src/xml_edits.rs
@@ -105,21 +105,58 @@ fn resolve_new_text_old_text_in_buffer(
 #[cfg(debug_assertions)]
 fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option<String> {
     let buffer_text = buffer.text();
-    let mut cursor = 0;
     let len = old_text.len();
+    if len == 0 || buffer_text.len() < len {
+        return None;
+    }
+
     let mut min_score = usize::MAX;
     let mut min_start = 0;
 
+    let old_text_bytes = old_text.as_bytes();
+    let old_alpha_count = old_text_bytes
+        .iter()
+        .filter(|&&b| b.is_ascii_alphanumeric())
+        .count();
+
+    let old_line_count = old_text.lines().count();
+
+    let mut cursor = 0;
+
     while cursor + len <= buffer_text.len() {
         let candidate = &buffer_text[cursor..cursor + len];
+        let candidate_bytes = candidate.as_bytes();
+
+        if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 {
+            cursor += 1;
+            continue;
+        }
+
+        let candidate_alpha_count = candidate_bytes
+            .iter()
+            .filter(|&&b| b.is_ascii_alphanumeric())
+            .count();
+
+        // If alphanumeric character count differs by more than 30%, skip
+        if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 {
+            cursor += 1;
+            continue;
+        }
+
         let score = strsim::levenshtein(candidate, old_text);
         if score < min_score {
             min_score = score;
             min_start = cursor;
+
+            if min_score <= len / 10 {
+                break;
+            }
         }
+
         cursor += 1;
     }
+
     if min_score != usize::MAX {
         Some(buffer_text[min_start..min_start + len].to_string())
     } else {
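The integer comparison above encodes the 30% threshold without floats: `abs_diff * 10 > old_alpha_count * 3` is exactly `abs_diff > 0.3 * old_alpha_count`. A quick boundary check:

```rust
// With old_alpha_count = 50, a candidate may differ by at most 15 alphanumerics.
fn kept(old_alpha_count: usize, candidate_alpha_count: usize) -> bool {
    usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 <= old_alpha_count * 3
}

fn main() {
    assert!(kept(50, 60)); // diff 10 <= 15: scored with levenshtein
    assert!(!kept(50, 70)); // diff 20 > 15: skipped without scoring
}
```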
diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs
index 6139c9c75e16f8805e6529dc1700eef1beacd713..d7bff2b51a69a031d2f24b0b357b9748dd5a473b 100644
--- a/crates/zeta2/src/zeta2.rs
+++ b/crates/zeta2/src/zeta2.rs
@@ -132,15 +132,8 @@ pub struct Zeta {
     options: ZetaOptions,
     update_required: bool,
     debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
-    #[cfg(feature = "llm-response-cache")]
-    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
-}
-
-#[cfg(feature = "llm-response-cache")]
-pub trait LlmResponseCache: Send + Sync {
-    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
-    fn read_response(&self, key: u64) -> Option<String>;
-    fn write_response(&self, key: u64, value: &str);
+    #[cfg(feature = "eval-support")]
+    eval_cache: Option<Arc<dyn EvalCache>>,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -369,14 +362,14 @@ impl Zeta {
             ),
             update_required: false,
             debug_tx: None,
-            #[cfg(feature = "llm-response-cache")]
-            llm_response_cache: None,
+            #[cfg(feature = "eval-support")]
+            eval_cache: None,
         }
     }
 
-    #[cfg(feature = "llm-response-cache")]
-    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
-        self.llm_response_cache = Some(cache);
+    #[cfg(feature = "eval-support")]
+    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+        self.eval_cache = Some(cache);
     }
 
     pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
@@ -736,9 +729,19 @@ impl Zeta {
         // TODO data collection
         let can_collect_data = cx.is_staff();
 
-        let mut included_files = project_state
+        let empty_context_files = HashMap::default();
+        let context_files = project_state
             .and_then(|project_state| project_state.context.as_ref())
-            .unwrap_or(&HashMap::default())
+            .unwrap_or(&empty_context_files);
+
+        #[cfg(feature = "eval-support")]
+        let parsed_fut = futures::future::join_all(
+            context_files
+                .keys()
+                .map(|buffer| buffer.read(cx).parsing_idle()),
+        );
+
+        let mut included_files = context_files
             .iter()
             .filter_map(|(buffer_entity, ranges)| {
                 let buffer = buffer_entity.read(cx);
@@ -751,12 +754,19 @@ impl Zeta {
             })
             .collect::<Vec<_>>();
 
-        #[cfg(feature = "llm-response-cache")]
-        let llm_response_cache = self.llm_response_cache.clone();
+        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
+            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
+        });
+
+        #[cfg(feature = "eval-support")]
+        let eval_cache = self.eval_cache.clone();
 
         let request_task = cx.background_spawn({
             let active_buffer = active_buffer.clone();
             async move {
+                #[cfg(feature = "eval-support")]
+                parsed_fut.await;
+
                 let index_state = if let Some(index_state) = index_state {
                     Some(index_state.lock_owned().await)
                 } else {
@@ -819,17 +829,17 @@ impl Zeta {
 
                 let included_files = included_files
                     .iter()
-                    .map(|(_, buffer, path, ranges)| {
+                    .map(|(_, snapshot, path, ranges)| {
                         let excerpts = merge_excerpts(
-                            &buffer,
+                            &snapshot,
                             ranges.iter().map(|range| {
-                                let point_range = range.to_point(&buffer);
+                                let point_range = range.to_point(&snapshot);
                                 Line(point_range.start.row)..Line(point_range.end.row)
                             }),
                         );
                         predict_edits_v3::IncludedFile {
                             path: path.clone(),
-                            max_row: Line(buffer.max_point().row),
+                            max_row: Line(snapshot.max_point().row),
                             excerpts,
                         }
                     })
@@ -948,8 +958,10 @@ impl Zeta {
                     client,
                     llm_token,
                     app_version,
-                    #[cfg(feature = "llm-response-cache")]
-                    llm_response_cache,
+                    #[cfg(feature = "eval-support")]
+                    eval_cache,
+                    #[cfg(feature = "eval-support")]
+                    EvalCacheEntryKind::Prediction,
                 )
                 .await;
                 let request_time = chrono::Utc::now() - before_request;
@@ -1049,9 +1061,8 @@ impl Zeta {
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: SemanticVersion,
-        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
-            Arc<dyn LlmResponseCache>,
-        >,
+        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
         let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
             http_client::Url::parse(&predict_edits_url)?
         } else {
             client
                 .http_client()
                 .build_zed_llm_url("/predict_edits/raw", &[])?
         };
 
-        #[cfg(feature = "llm-response-cache")]
-        let cache_key = if let Some(cache) = llm_response_cache {
-            let request_json = serde_json::to_string(&request)?;
-            let key = cache.get_key(&url, &request_json);
+        #[cfg(feature = "eval-support")]
+        let cache_key = if let Some(cache) = eval_cache {
+            use collections::FxHasher;
+            use std::hash::{Hash, Hasher};
+
+            let mut hasher = FxHasher::default();
+            url.hash(&mut hasher);
+            let request_str = serde_json::to_string_pretty(&request)?;
+            request_str.hash(&mut hasher);
+            let hash = hasher.finish();
 
-            if let Some(response_str) = cache.read_response(key) {
+            let key = (eval_cache_kind, hash);
+            if let Some(response_str) = cache.read(key) {
                 return Ok((serde_json::from_str(&response_str)?, None));
             }
-            Some((cache, key))
+            Some((cache, request_str, key))
         } else {
             None
         };
@@ -1088,9 +1106,9 @@ impl Zeta {
         )
         .await?;
 
-        #[cfg(feature = "llm-response-cache")]
-        if let Some((cache, key)) = cache_key {
-            cache.write_response(key, &serde_json::to_string(&response)?);
+        #[cfg(feature = "eval-support")]
+        if let Some((cache, request, key)) = cache_key {
+            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
         }
 
         Ok((response, usage))
@@ -1361,8 +1379,8 @@ impl Zeta {
             reasoning_effort: None,
         };
 
-        #[cfg(feature = "llm-response-cache")]
-        let llm_response_cache = self.llm_response_cache.clone();
+        #[cfg(feature = "eval-support")]
+        let eval_cache = self.eval_cache.clone();
 
         cx.spawn(async move |this, cx| {
             log::trace!("Sending search planning request");
@@ -1371,8 +1389,10 @@ impl Zeta {
                 client,
                 llm_token,
                 app_version,
-                #[cfg(feature = "llm-response-cache")]
-                llm_response_cache,
+                #[cfg(feature = "eval-support")]
+                eval_cache.clone(),
+                #[cfg(feature = "eval-support")]
+                EvalCacheEntryKind::Context,
             )
             .await;
             let mut response = Self::handle_api_response(&this, response, cx)?;
@@ -1421,8 +1441,14 @@ impl Zeta {
 
                 log::trace!("Running retrieval search: {queries:#?}");
 
-                let related_excerpts_result =
-                    retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
+                let related_excerpts_result = retrieval_search::run_retrieval_searches(
+                    queries,
+                    project.clone(),
+                    #[cfg(feature = "eval-support")]
+                    eval_cache,
+                    cx,
+                )
+                .await;
 
                 log::trace!("Search queries executed");
 
@@ -1772,6 +1798,34 @@ fn add_signature(
     Some(signature_index)
 }
 
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
+
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+    Context,
+    Search,
+    Prediction,
+}
+
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            EvalCacheEntryKind::Search => write!(f, "search"),
+            EvalCacheEntryKind::Context => write!(f, "context"),
+            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
+        }
+    }
+}
+
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+    fn read(&self, key: EvalCacheKey) -> Option<String>;
+    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
+
 #[cfg(test)]
 mod tests {
     use std::{path::Path, sync::Arc};
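The `EvalCache` trait above is deliberately minimal: `read` returns a cached output by key, and `write` persists both the input (for later inspection) and the output. A minimal in-memory implementation, assuming only what the trait requires (`MemoryCache` is hypothetical, e.g. for tests; the real `RunCache` in `zeta_cli` below persists entries to disk):

```rust
use std::sync::Mutex;

// Hypothetical in-memory cache. A linear scan avoids needing Eq + Hash on
// EvalCacheKey, since EvalCacheEntryKind only derives PartialEq.
#[derive(Default)]
struct MemoryCache {
    entries: Mutex<Vec<(EvalCacheKey, String)>>,
}

impl EvalCache for MemoryCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let entries = self.entries.lock().unwrap();
        entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
    }

    fn write(&self, key: EvalCacheKey, _input: &str, value: &str) {
        self.entries.lock().unwrap().push((key, value.to_string()));
    }
}
```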
diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml
index 2e62f2a4462e31b7632aa5e825ea76a4b7df5fc8..e18cf54787ca98e2be60db4977dd2de18e9c09e2 100644
--- a/crates/zeta_cli/Cargo.toml
+++ b/crates/zeta_cli/Cargo.toml
@@ -54,7 +54,7 @@ toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
-zeta2 = { workspace = true, features = ["llm-response-cache"] }
+zeta2 = { workspace = true, features = ["eval-support"] }
 zlog.workspace = true
 
 [dev-dependencies]
diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs
index b5c23af24845a90d153943f6ee2ccd29bbfaf6a7..0359ccf0fea3179dd480645ad7031b61fc3a357c 100644
--- a/crates/zeta_cli/src/evaluate.rs
+++ b/crates/zeta_cli/src/evaluate.rs
@@ -14,18 +14,19 @@ use crate::{
     PromptFormat,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
-    predict::{PredictionDetails, zeta2_predict},
+    paths::print_run_data_dir,
+    predict::{CacheMode, PredictionDetails, zeta2_predict},
 };
 
 #[derive(Debug, Args)]
 pub struct EvaluateArguments {
     example_paths: Vec<PathBuf>,
-    #[clap(long)]
-    skip_cache: bool,
     #[arg(long, value_enum, default_value_t = PromptFormat::default())]
     prompt_format: PromptFormat,
     #[arg(long)]
     use_expected_context: bool,
+    #[clap(long, value_enum, default_value_t = CacheMode::default())]
+    cache: CacheMode,
 }
 
 pub async fn run_evaluate(
@@ -39,43 +40,49 @@ pub async fn run_evaluate(
         cx.spawn(async move |cx| {
             run_evaluate_one(
                 &path,
-                args.skip_cache,
                 args.prompt_format,
                 args.use_expected_context,
+                args.cache,
                 app_state.clone(),
                 cx,
             )
             .await
         })
     });
-    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
+    let all_results = futures::future::try_join_all(all_tasks).await;
+
+    if let Ok(all_results) = &all_results {
+        let aggregated_result = EvaluationResult {
+            context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
+            edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+        };
+
+        if example_len > 1 {
+            println!("\n{}", "-".repeat(80));
+            println!("\n## TOTAL SCORES");
+            println!("{}", aggregated_result.to_markdown());
+        }
+    }
 
-    let aggregated_result = EvaluationResult {
-        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
-        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
-    };
+    print_run_data_dir();
 
-    if example_len > 1 {
-        println!("\n{}", "-".repeat(80));
-        println!("# TOTAL SCORES:");
-        println!("{}", aggregated_result.to_markdown());
-    }
+    all_results.unwrap();
 }
 
 pub async fn run_evaluate_one(
     example_path: &Path,
-    skip_cache: bool,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    cache_mode: CacheMode,
     app_state: Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<EvaluationResult> {
     let example = NamedExample::load(&example_path).unwrap();
 
     let predictions = zeta2_predict(
         example.clone(),
-        skip_cache,
         prompt_format,
         use_expected_context,
+        cache_mode,
         &app_state,
         cx,
     )
diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs
index 20176fbb5d73de83b90b8edb2831104ecddc8ef0..3e55fb0b62e0191fa5abf1014a71bc7f613fc0c9 100644
--- a/crates/zeta_cli/src/example.rs
+++ b/crates/zeta_cli/src/example.rs
@@ -398,7 +398,7 @@ impl NamedExample {
         Ok(worktree_path)
     }
 
-    fn file_name(&self) -> String {
+    pub fn file_name(&self) -> String {
         self.name
             .chars()
             .map(|c| {
diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs
index 82760d6061d9b96a2da74bf5cb24e43d9ecdba60..1dd246e612979e7a4a77c74926be1a5cab72dbc6 100644
--- a/crates/zeta_cli/src/main.rs
+++ b/crates/zeta_cli/src/main.rs
@@ -54,6 +54,7 @@ enum Command {
         #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
         output_format: ExampleFormat,
     },
+    Clean,
 }
 
 #[derive(Subcommand, Debug)]
@@ -470,6 +471,7 @@ fn main() {
                 let example = NamedExample::load(path).unwrap();
                 example.write(output_format, io::stdout()).unwrap();
             }
+            Command::Clean => std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap(),
         };
 
         let _ = cx.update(|cx| cx.quit());
diff --git a/crates/zeta_cli/src/paths.rs b/crates/zeta_cli/src/paths.rs
index fc7f8b3afc3dbcd724649749a58b76dbab275750..73d541c6a0409deab5baac1714feded986fb94c1 100644
--- a/crates/zeta_cli/src/paths.rs
+++ b/crates/zeta_cli/src/paths.rs
@@ -1,16 +1,40 @@
 use std::{env, path::PathBuf, sync::LazyLock};
 
-static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
-pub static CACHE_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
-pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
-pub static LOGS_SEARCH_PROMPT: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("search_prompt.md"));
-pub static LOGS_SEARCH_QUERIES: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("search_queries.json"));
-pub static LOGS_PREDICTION_PROMPT: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("prediction_prompt.md"));
-pub static LOGS_PREDICTION_RESPONSE: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("prediction_response.md"));
+pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
+    LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
+pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+    TARGET_ZETA_DIR
+        .join("runs")
+        .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
+});
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
+    LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
+
+pub fn print_run_data_dir() {
+    println!("\n## Run Data\n");
+
+    let current_dir = std::env::current_dir().unwrap();
+    for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
+        let file = file.unwrap();
+        if file.file_type().unwrap().is_dir() {
+            for file in std::fs::read_dir(file.path()).unwrap() {
+                let path = file.unwrap().path();
+                let path = path.strip_prefix(&current_dir).unwrap_or(&path);
+                println!(
+                    "- {}/\x1b[34m{}\x1b[0m",
+                    path.parent().unwrap().display(),
+                    path.file_name().unwrap().display(),
+                );
+            }
+        } else {
+            let path = file.path();
+            println!(
+                "- {} ",
+                path.strip_prefix(&current_dir).unwrap_or(&path).display()
+            );
+        }
+    }
+}
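With these changes, everything the CLI writes lands under `target/zeta`. Roughly (directory names from the statics above; run timestamps follow the `%d-%m-%y-%H_%M_%S` format):

```
target/zeta/
├── cache/                        # persisted eval-cache entries
├── repos/
├── worktrees/
├── runs/<timestamp>/<example>/   # per-example artifacts for one run
└── latest                        # symlink to the most recent example run dir
```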
diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs
index 32f2f564fc53df987579bf2946eb5765519157c6..82108df076c025089f5e374f447a3136fdb0c563 100644
--- a/crates/zeta_cli/src/predict.rs
+++ b/crates/zeta_cli/src/predict.rs
@@ -1,20 +1,15 @@
 use crate::PromptFormat;
 use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
-use crate::paths::{
-    CACHE_DIR, LOGS_DIR, LOGS_PREDICTION_PROMPT, LOGS_PREDICTION_RESPONSE, LOGS_SEARCH_PROMPT,
-    LOGS_SEARCH_QUERIES,
-};
+use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
 use ::serde::Serialize;
-use anyhow::{Result, anyhow};
-use clap::Args;
-use collections::HashMap;
-use gpui::http_client::Url;
-use language::{Anchor, Buffer, Point};
-// use cloud_llm_client::predict_edits_v3::PromptFormat;
+use anyhow::{Context, Result, anyhow};
+use clap::{Args, ValueEnum};
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use collections::HashMap;
 use futures::StreamExt as _;
 use gpui::{AppContext, AsyncApp, Entity};
+use language::{Anchor, Buffer, Point};
 use project::Project;
 use serde::Deserialize;
 use std::cell::Cell;
@@ -25,7 +20,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
-use zeta2::LlmResponseCache;
+use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey};
 
 #[derive(Debug, Args)]
 pub struct PredictArguments {
@@ -36,8 +31,31 @@ pub struct PredictArguments {
     #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
     format: PredictionsOutputFormat,
     example_path: PathBuf,
-    #[clap(long)]
-    skip_cache: bool,
+    #[clap(long, value_enum, default_value_t = CacheMode::default())]
+    cache: CacheMode,
+}
+
+#[derive(Debug, ValueEnum, Default, Clone, Copy)]
+pub enum CacheMode {
+    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
+    #[default]
+    #[value(alias = "request")]
+    Requests,
+    /// Ignore existing cache entries for both LLM and search.
+    Skip,
+    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't
+    /// been cached yet. Useful for reproducing results and fixing bugs outside of search queries.
+    Force,
+}
+
+impl CacheMode {
+    fn use_cached_llm_responses(&self) -> bool {
+        matches!(self, CacheMode::Requests | CacheMode::Force)
+    }
+
+    fn use_cached_search_results(&self) -> bool {
+        matches!(self, CacheMode::Force)
+    }
 }
 
 #[derive(clap::ValueEnum, Debug, Clone)]
@@ -55,9 +73,9 @@ pub async fn run_zeta2_predict(
     let example = NamedExample::load(args.example_path).unwrap();
     let result = zeta2_predict(
         example,
-        args.skip_cache,
         args.prompt_format,
         args.use_expected_context,
+        args.cache,
         &app_state,
         cx,
     )
     .await
     .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
 
-    println!("## Logs\n");
-    println!("Search prompt: {}", LOGS_SEARCH_PROMPT.display());
-    println!("Search queries: {}", LOGS_SEARCH_QUERIES.display());
-    println!("Prediction prompt: {}", LOGS_PREDICTION_PROMPT.display());
-    println!(
-        "Prediction response: {}",
-        LOGS_PREDICTION_RESPONSE.display()
-    );
+    print_run_data_dir();
 }
 
 thread_local! {
@@ -84,13 +95,12 @@
 
 pub async fn zeta2_predict(
     example: NamedExample,
-    skip_cache: bool,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    cache_mode: CacheMode,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<PredictionDetails> {
-    fs::create_dir_all(&*LOGS_DIR)?;
     let worktree_path = example.setup_worktree().await?;
 
     if !AUTHENTICATED.get() {
@@ -126,8 +136,25 @@ pub async fn zeta2_predict(
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
+    let example_run_dir = RUN_DIR.join(&example.file_name());
+    fs::create_dir_all(&example_run_dir)?;
+    if LATEST_EXAMPLE_RUN_DIR.exists() {
+        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+    }
+
+    #[cfg(unix)]
+    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+        .context("creating latest link")?;
+
+    #[cfg(windows)]
+    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+        .context("creating latest link")?;
+
     zeta.update(cx, |zeta, _cx| {
-        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+        zeta.with_eval_cache(Arc::new(RunCache {
+            example_run_dir: example_run_dir.clone(),
+            cache_mode,
+        }));
     })?;
 
     cx.subscribe(&buffer_store, {
@@ -159,12 +186,15 @@ pub async fn zeta2_predict(
             match event {
                 zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                     start_time = Some(info.timestamp);
-                    fs::write(&*LOGS_SEARCH_PROMPT, &info.search_prompt)?;
+                    fs::write(
+                        example_run_dir.join("search_prompt.md"),
+                        &info.search_prompt,
+                    )?;
                 }
                 zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                     search_queries_generated_at = Some(info.timestamp);
                     fs::write(
-                        &*LOGS_SEARCH_QUERIES,
+                        example_run_dir.join("search_queries.json"),
                         serde_json::to_string_pretty(&info.search_queries).unwrap(),
                     )?;
                 }
@@ -176,7 +206,7 @@ pub async fn zeta2_predict(
                     let prediction_started_at = Instant::now();
                     start_time.get_or_insert(prediction_started_at);
                     fs::write(
-                        &*LOGS_PREDICTION_PROMPT,
+                        example_run_dir.join("prediction_prompt.md"),
                         &request.local_prompt.unwrap_or_default(),
                     )?;
 
@@ -210,7 +240,7 @@ pub async fn zeta2_predict(
         let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
         let response = zeta2::text_from_response(response).unwrap_or_default();
         let prediction_finished_at = Instant::now();
-        fs::write(&*LOGS_PREDICTION_RESPONSE, &response)?;
+        fs::write(example_run_dir.join("prediction_response.md"), &response)?;
 
         let mut result = result.lock().unwrap();
 
@@ -328,48 +358,69 @@ async fn resolve_context_entry(
     Ok((buffer, ranges))
 }
 
-struct Cache {
-    skip_cache: bool,
+struct RunCache {
+    cache_mode: CacheMode,
+    example_run_dir: PathBuf,
 }
 
-impl Cache {
-    fn path(key: u64) -> PathBuf {
-        CACHE_DIR.join(format!("{key:x}.json"))
+impl RunCache {
+    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+        CACHE_DIR.join(format!("{kind}_out_{key:x}.json"))
     }
-}
 
-impl LlmResponseCache for Cache {
-    fn get_key(&self, url: &Url, body: &str) -> u64 {
-        use collections::FxHasher;
-        use std::hash::{Hash, Hasher};
+    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+        CACHE_DIR.join(format!("{kind}_in_{key:x}.json"))
+    }
 
-        let mut hasher = FxHasher::default();
-        url.hash(&mut hasher);
-        body.hash(&mut hasher);
-        hasher.finish()
+    fn link_to_run(&self, key: &EvalCacheKey) {
+        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
+        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+
+        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
+        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
     }
+}
+
+impl EvalCache for RunCache {
+    fn read(&self, key: EvalCacheKey) -> Option<String> {
+        let path = RunCache::output_cache_path(&key);
 
-    fn read_response(&self, key: u64) -> Option<String> {
-        let path = Cache::path(key);
         if path.exists() {
-            if self.skip_cache {
-                log::info!("Skipping existing cached LLM response: {}", path.display());
-                None
-            } else {
-                log::info!("Using LLM response from cache: {}", path.display());
+            let use_cache = match key.0 {
+                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
+                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
+                    self.cache_mode.use_cached_llm_responses()
+                }
+            };
+            if use_cache {
+                log::info!("Using cache entry: {}", path.display());
+                self.link_to_run(&key);
                 Some(fs::read_to_string(path).unwrap())
+            } else {
+                log::info!("Skipping cached entry: {}", path.display());
+                None
             }
+        } else if matches!(self.cache_mode, CacheMode::Force) {
+            panic!(
+                "No cached entry found for {:?}. Run without `--cache force` at least once.",
+                key.0
+            );
         } else {
             None
         }
     }
 
-    fn write_response(&self, key: u64, value: &str) {
+    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
         fs::create_dir_all(&*CACHE_DIR).unwrap();
-        let path = Cache::path(key);
-        log::info!("Writing LLM response to cache: {}", path.display());
-        fs::write(path, value).unwrap();
+        let input_path = RunCache::input_cache_path(&key);
+        fs::write(&input_path, input).unwrap();
+
+        let output_path = RunCache::output_cache_path(&key);
+        log::info!("Writing cache entry: {}", output_path.display());
+        fs::write(&output_path, output).unwrap();
+
+        self.link_to_run(&key);
    }
 }
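Taken together, cache behavior is now selected per entry kind rather than globally: `--cache requests` (the default) reuses cached LLM calls but re-runs searches, `--cache skip` ignores every cached entry, and `--cache force` demands cached LLM and search entries alike, panicking on a miss so runs are fully deterministic. The new `Clean` subcommand (invocation depends on how the binary is run) wipes `target/zeta` entirely, including the cache, cloned repos, worktrees, and run data.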