zeta eval: Improve determinism and debugging ergonomics (#42478)

Ben Kunkle and Agus created

- Improves the determinism of the search step for better cache
reusability
- Adds a `--cache force` mode that refuses to make any requests or
searches that aren't cached
- The structure of the `zeta-*` directories under `target` has been
rethought for convenience

Release Notes:

- N/A

---------

Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs |   2 
crates/zeta2/Cargo.toml                           |   2 
crates/zeta2/src/retrieval_search.rs              |  95 +++++++++
crates/zeta2/src/xml_edits.rs                     |  39 ++++
crates/zeta2/src/zeta2.rs                         | 138 ++++++++++----
crates/zeta_cli/Cargo.toml                        |   2 
crates/zeta_cli/src/evaluate.rs                   |  39 ++-
crates/zeta_cli/src/example.rs                    |   2 
crates/zeta_cli/src/main.rs                       |   2 
crates/zeta_cli/src/paths.rs                      |  52 ++++-
crates/zeta_cli/src/predict.rs                    | 159 +++++++++++-----
11 files changed, 395 insertions(+), 137 deletions(-)

Detailed changes

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs 🔗

@@ -44,7 +44,7 @@ pub struct SearchToolInput {
 }
 
 /// Search for relevant code by path, syntax hierarchy, and content.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
 pub struct SearchToolQuery {
     /// 1. A glob pattern to match file paths in the codebase to search in.
     pub glob: String,

crates/zeta2/Cargo.toml 🔗

@@ -12,7 +12,7 @@ workspace = true
 path = "src/zeta2.rs"
 
 [features]
-llm-response-cache = []
+eval-support = []
 
 [dependencies]
 anyhow.workspace = true

crates/zeta2/src/retrieval_search.rs 🔗

@@ -1,5 +1,3 @@
-use std::ops::Range;
-
 use anyhow::Result;
 use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use collections::HashMap;
@@ -14,17 +12,76 @@ use project::{
     search::{SearchQuery, SearchResult},
 };
 use smol::channel;
+use std::ops::Range;
 use util::{
     ResultExt as _,
     paths::{PathMatcher, PathStyle},
 };
 use workspace::item::Settings as _;
 
+#[cfg(feature = "eval-support")]
+type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
+
 pub async fn run_retrieval_searches(
-    project: Entity<Project>,
     queries: Vec<SearchToolQuery>,
+    project: Entity<Project>,
+    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
     cx: &mut AsyncApp,
 ) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
+    #[cfg(feature = "eval-support")]
+    let cache = if let Some(eval_cache) = eval_cache {
+        use crate::EvalCacheEntryKind;
+        use anyhow::Context;
+        use collections::FxHasher;
+        use std::hash::{Hash, Hasher};
+
+        let mut hasher = FxHasher::default();
+        project.read_with(cx, |project, cx| {
+            let mut worktrees = project.worktrees(cx);
+            let Some(worktree) = worktrees.next() else {
+                panic!("Expected a single worktree in eval project. Found none.");
+            };
+            assert!(
+                worktrees.next().is_none(),
+                "Expected a single worktree in eval project. Found more than one."
+            );
+            worktree.read(cx).abs_path().hash(&mut hasher);
+        })?;
+
+        queries.hash(&mut hasher);
+        let key = (EvalCacheEntryKind::Search, hasher.finish());
+
+        if let Some(cached_results) = eval_cache.read(key) {
+            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
+                .context("Failed to deserialize cached search results")?;
+            let mut results = HashMap::default();
+
+            for (path, ranges) in file_results {
+                let buffer = project
+                    .update(cx, |project, cx| {
+                        let project_path = project.find_project_path(path, cx).unwrap();
+                        project.open_buffer(project_path, cx)
+                    })?
+                    .await?;
+                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+                let mut ranges = ranges
+                    .into_iter()
+                    .map(|range| {
+                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
+                    })
+                    .collect();
+                merge_anchor_ranges(&mut ranges, &snapshot);
+                results.insert(buffer, ranges);
+            }
+
+            return Ok(results);
+        }
+
+        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
+    } else {
+        None
+    };
+
     let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
         let global_settings = WorktreeSettings::get_global(cx);
         let exclude_patterns = global_settings
@@ -58,6 +115,8 @@ pub async fn run_retrieval_searches(
     }
     drop(results_tx);
 
+    #[cfg(feature = "eval-support")]
+    let cache = cache.clone();
     cx.background_spawn(async move {
         let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
         let mut snapshots = HashMap::default();
@@ -79,6 +138,29 @@ pub async fn run_retrieval_searches(
             }
         }
 
+        #[cfg(feature = "eval-support")]
+        if let Some((cache, queries, key)) = cache {
+            let cached_results: CachedSearchResults = results
+                .iter()
+                .filter_map(|(buffer, ranges)| {
+                    let snapshot = snapshots.get(&buffer.entity_id())?;
+                    let path = snapshot.file().map(|f| f.path());
+                    let mut ranges = ranges
+                        .iter()
+                        .map(|range| range.to_offset(&snapshot))
+                        .collect::<Vec<_>>();
+                    ranges.sort_unstable_by_key(|range| (range.start, range.end));
+
+                    Some((path?.as_std_path().to_path_buf(), ranges))
+                })
+                .collect();
+            cache.write(
+                key,
+                &queries,
+                &serde_json::to_string_pretty(&cached_results)?,
+            );
+        }
+
         for (buffer, ranges) in results.iter_mut() {
             if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
                 merge_anchor_ranges(ranges, snapshot);
@@ -489,9 +571,10 @@ mod tests {
         expected_output: &str,
         cx: &mut TestAppContext,
     ) {
-        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
-            .await
-            .unwrap();
+        let results =
+            run_retrieval_searches(vec![query], project.clone(), None, &mut cx.to_async())
+                .await
+                .unwrap();
 
         let mut results = results.into_iter().collect::<Vec<_>>();
         results.sort_by_key(|results| {

crates/zeta2/src/xml_edits.rs 🔗

@@ -105,21 +105,58 @@ fn resolve_new_text_old_text_in_buffer(
 #[cfg(debug_assertions)]
 fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option<String> {
     let buffer_text = buffer.text();
-    let mut cursor = 0;
     let len = old_text.len();
 
+    if len == 0 || buffer_text.len() < len {
+        return None;
+    }
+
     let mut min_score = usize::MAX;
     let mut min_start = 0;
 
+    let old_text_bytes = old_text.as_bytes();
+    let old_alpha_count = old_text_bytes
+        .iter()
+        .filter(|&&b| b.is_ascii_alphanumeric())
+        .count();
+
+    let old_line_count = old_text.lines().count();
+
+    let mut cursor = 0;
+
     while cursor + len <= buffer_text.len() {
         let candidate = &buffer_text[cursor..cursor + len];
+        let candidate_bytes = candidate.as_bytes();
+
+        if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 {
+            cursor += 1;
+            continue;
+        }
+
+        let candidate_alpha_count = candidate_bytes
+            .iter()
+            .filter(|&&b| b.is_ascii_alphanumeric())
+            .count();
+
+        // If alphanumeric character count differs by more than 30%, skip
+        if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 {
+            cursor += 1;
+            continue;
+        }
+
         let score = strsim::levenshtein(candidate, old_text);
         if score < min_score {
             min_score = score;
             min_start = cursor;
+
+            if min_score <= len / 10 {
+                break;
+            }
         }
+
         cursor += 1;
     }
+
     if min_score != usize::MAX {
         Some(buffer_text[min_start..min_start + len].to_string())
     } else {

crates/zeta2/src/zeta2.rs 🔗

@@ -132,15 +132,8 @@ pub struct Zeta {
     options: ZetaOptions,
     update_required: bool,
     debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
-    #[cfg(feature = "llm-response-cache")]
-    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
-}
-
-#[cfg(feature = "llm-response-cache")]
-pub trait LlmResponseCache: Send + Sync {
-    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
-    fn read_response(&self, key: u64) -> Option<String>;
-    fn write_response(&self, key: u64, value: &str);
+    #[cfg(feature = "eval-support")]
+    eval_cache: Option<Arc<dyn EvalCache>>,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -369,14 +362,14 @@ impl Zeta {
             ),
             update_required: false,
             debug_tx: None,
-            #[cfg(feature = "llm-response-cache")]
-            llm_response_cache: None,
+            #[cfg(feature = "eval-support")]
+            eval_cache: None,
         }
     }
 
-    #[cfg(feature = "llm-response-cache")]
-    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
-        self.llm_response_cache = Some(cache);
+    #[cfg(feature = "eval-support")]
+    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+        self.eval_cache = Some(cache);
     }
 
     pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
@@ -736,9 +729,19 @@ impl Zeta {
         // TODO data collection
         let can_collect_data = cx.is_staff();
 
-        let mut included_files = project_state
+        let empty_context_files = HashMap::default();
+        let context_files = project_state
             .and_then(|project_state| project_state.context.as_ref())
-            .unwrap_or(&HashMap::default())
+            .unwrap_or(&empty_context_files);
+
+        #[cfg(feature = "eval-support")]
+        let parsed_fut = futures::future::join_all(
+            context_files
+                .keys()
+                .map(|buffer| buffer.read(cx).parsing_idle()),
+        );
+
+        let mut included_files = context_files
             .iter()
             .filter_map(|(buffer_entity, ranges)| {
                 let buffer = buffer_entity.read(cx);
@@ -751,12 +754,19 @@ impl Zeta {
             })
             .collect::<Vec<_>>();
 
-        #[cfg(feature = "llm-response-cache")]
-        let llm_response_cache = self.llm_response_cache.clone();
+        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
+            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
+        });
+
+        #[cfg(feature = "eval-support")]
+        let eval_cache = self.eval_cache.clone();
 
         let request_task = cx.background_spawn({
             let active_buffer = active_buffer.clone();
             async move {
+                #[cfg(feature = "eval-support")]
+                parsed_fut.await;
+
                 let index_state = if let Some(index_state) = index_state {
                     Some(index_state.lock_owned().await)
                 } else {
@@ -819,17 +829,17 @@ impl Zeta {
 
                         let included_files = included_files
                             .iter()
-                            .map(|(_, buffer, path, ranges)| {
+                            .map(|(_, snapshot, path, ranges)| {
                                 let excerpts = merge_excerpts(
-                                    &buffer,
+                                    &snapshot,
                                     ranges.iter().map(|range| {
-                                        let point_range = range.to_point(&buffer);
+                                        let point_range = range.to_point(&snapshot);
                                         Line(point_range.start.row)..Line(point_range.end.row)
                                     }),
                                 );
                                 predict_edits_v3::IncludedFile {
                                     path: path.clone(),
-                                    max_row: Line(buffer.max_point().row),
+                                    max_row: Line(snapshot.max_point().row),
                                     excerpts,
                                 }
                             })
@@ -948,8 +958,10 @@ impl Zeta {
                     client,
                     llm_token,
                     app_version,
-                    #[cfg(feature = "llm-response-cache")]
-                    llm_response_cache,
+                    #[cfg(feature = "eval-support")]
+                    eval_cache,
+                    #[cfg(feature = "eval-support")]
+                    EvalCacheEntryKind::Prediction,
                 )
                 .await;
                 let request_time = chrono::Utc::now() - before_request;
@@ -1049,9 +1061,8 @@ impl Zeta {
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: SemanticVersion,
-        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
-            Arc<dyn LlmResponseCache>,
-        >,
+        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
         let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
             http_client::Url::parse(&predict_edits_url)?
@@ -1061,16 +1072,23 @@ impl Zeta {
                 .build_zed_llm_url("/predict_edits/raw", &[])?
         };
 
-        #[cfg(feature = "llm-response-cache")]
-        let cache_key = if let Some(cache) = llm_response_cache {
-            let request_json = serde_json::to_string(&request)?;
-            let key = cache.get_key(&url, &request_json);
+        #[cfg(feature = "eval-support")]
+        let cache_key = if let Some(cache) = eval_cache {
+            use collections::FxHasher;
+            use std::hash::{Hash, Hasher};
+
+            let mut hasher = FxHasher::default();
+            url.hash(&mut hasher);
+            let request_str = serde_json::to_string_pretty(&request)?;
+            request_str.hash(&mut hasher);
+            let hash = hasher.finish();
 
-            if let Some(response_str) = cache.read_response(key) {
+            let key = (eval_cache_kind, hash);
+            if let Some(response_str) = cache.read(key) {
                 return Ok((serde_json::from_str(&response_str)?, None));
             }
 
-            Some((cache, key))
+            Some((cache, request_str, key))
         } else {
             None
         };
@@ -1088,9 +1106,9 @@ impl Zeta {
         )
         .await?;
 
-        #[cfg(feature = "llm-response-cache")]
-        if let Some((cache, key)) = cache_key {
-            cache.write_response(key, &serde_json::to_string(&response)?);
+        #[cfg(feature = "eval-support")]
+        if let Some((cache, request, key)) = cache_key {
+            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
         }
 
         Ok((response, usage))
@@ -1361,8 +1379,8 @@ impl Zeta {
             reasoning_effort: None,
         };
 
-        #[cfg(feature = "llm-response-cache")]
-        let llm_response_cache = self.llm_response_cache.clone();
+        #[cfg(feature = "eval-support")]
+        let eval_cache = self.eval_cache.clone();
 
         cx.spawn(async move |this, cx| {
             log::trace!("Sending search planning request");
@@ -1371,8 +1389,10 @@ impl Zeta {
                 client,
                 llm_token,
                 app_version,
-                #[cfg(feature = "llm-response-cache")]
-                llm_response_cache,
+                #[cfg(feature = "eval-support")]
+                eval_cache.clone(),
+                #[cfg(feature = "eval-support")]
+                EvalCacheEntryKind::Context,
             )
             .await;
             let mut response = Self::handle_api_response(&this, response, cx)?;
@@ -1421,8 +1441,14 @@ impl Zeta {
 
             log::trace!("Running retrieval search: {queries:#?}");
 
-            let related_excerpts_result =
-                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
+            let related_excerpts_result = retrieval_search::run_retrieval_searches(
+                queries,
+                project.clone(),
+                #[cfg(feature = "eval-support")]
+                eval_cache,
+                cx,
+            )
+            .await;
 
             log::trace!("Search queries executed");
 
@@ -1772,6 +1798,34 @@ fn add_signature(
     Some(signature_index)
 }
 
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
+
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+    Context,
+    Search,
+    Prediction,
+}
+
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            EvalCacheEntryKind::Search => write!(f, "search"),
+            EvalCacheEntryKind::Context => write!(f, "context"),
+            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
+        }
+    }
+}
+
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+    fn read(&self, key: EvalCacheKey) -> Option<String>;
+    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
+
 #[cfg(test)]
 mod tests {
     use std::{path::Path, sync::Arc};

crates/zeta_cli/Cargo.toml 🔗

@@ -54,7 +54,7 @@ toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
-zeta2 = { workspace = true, features = ["llm-response-cache"] }
+zeta2 = { workspace = true, features = ["eval-support"] }
 zlog.workspace = true
 
 [dev-dependencies]

crates/zeta_cli/src/evaluate.rs 🔗

@@ -14,18 +14,19 @@ use crate::{
     PromptFormat,
     example::{Example, NamedExample},
     headless::ZetaCliAppState,
-    predict::{PredictionDetails, zeta2_predict},
+    paths::print_run_data_dir,
+    predict::{CacheMode, PredictionDetails, zeta2_predict},
 };
 
 #[derive(Debug, Args)]
 pub struct EvaluateArguments {
     example_paths: Vec<PathBuf>,
-    #[clap(long)]
-    skip_cache: bool,
     #[arg(long, value_enum, default_value_t = PromptFormat::default())]
     prompt_format: PromptFormat,
     #[arg(long)]
     use_expected_context: bool,
+    #[clap(long, value_enum, default_value_t = CacheMode::default())]
+    cache: CacheMode,
 }
 
 pub async fn run_evaluate(
@@ -39,43 +40,49 @@ pub async fn run_evaluate(
         cx.spawn(async move |cx| {
             run_evaluate_one(
                 &path,
-                args.skip_cache,
                 args.prompt_format,
                 args.use_expected_context,
+                args.cache,
                 app_state.clone(),
                 cx,
             )
             .await
         })
     });
-    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
+    let all_results = futures::future::try_join_all(all_tasks).await;
+
+    if let Ok(all_results) = &all_results {
+        let aggregated_result = EvaluationResult {
+            context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
+            edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+        };
+
+        if example_len > 1 {
+            println!("\n{}", "-".repeat(80));
+            println!("\n## TOTAL SCORES");
+            println!("{}", aggregated_result.to_markdown());
+        }
+    }
 
-    let aggregated_result = EvaluationResult {
-        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
-        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
-    };
+    print_run_data_dir();
 
-    if example_len > 1 {
-        println!("\n{}", "-".repeat(80));
-        println!("# TOTAL SCORES:");
-        println!("{}", aggregated_result.to_markdown());
-    }
+    all_results.unwrap();
 }
 
 pub async fn run_evaluate_one(
     example_path: &Path,
-    skip_cache: bool,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    cache_mode: CacheMode,
     app_state: Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<EvaluationResult> {
     let example = NamedExample::load(&example_path).unwrap();
     let predictions = zeta2_predict(
         example.clone(),
-        skip_cache,
         prompt_format,
         use_expected_context,
+        cache_mode,
         &app_state,
         cx,
     )

crates/zeta_cli/src/example.rs 🔗

@@ -398,7 +398,7 @@ impl NamedExample {
         Ok(worktree_path)
     }
 
-    fn file_name(&self) -> String {
+    pub fn file_name(&self) -> String {
         self.name
             .chars()
             .map(|c| {

crates/zeta_cli/src/main.rs 🔗

@@ -54,6 +54,7 @@ enum Command {
         #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
         output_format: ExampleFormat,
     },
+    Clean,
 }
 
 #[derive(Subcommand, Debug)]
@@ -470,6 +471,7 @@ fn main() {
                     let example = NamedExample::load(path).unwrap();
                     example.write(output_format, io::stdout()).unwrap();
                 }
+                Command::Clean => std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap(),
             };
 
             let _ = cx.update(|cx| cx.quit());

crates/zeta_cli/src/paths.rs 🔗

@@ -1,16 +1,40 @@
 use std::{env, path::PathBuf, sync::LazyLock};
 
-static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
-pub static CACHE_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
-pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
-pub static LOGS_SEARCH_PROMPT: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("search_prompt.md"));
-pub static LOGS_SEARCH_QUERIES: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("search_queries.json"));
-pub static LOGS_PREDICTION_PROMPT: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("prediction_prompt.md"));
-pub static LOGS_PREDICTION_RESPONSE: LazyLock<PathBuf> =
-    LazyLock::new(|| LOGS_DIR.join("prediction_response.md"));
+pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
+    LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
+pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+    TARGET_ZETA_DIR
+        .join("runs")
+        .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
+});
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
+    LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
+
+pub fn print_run_data_dir() {
+    println!("\n## Run Data\n");
+
+    let current_dir = std::env::current_dir().unwrap();
+    for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
+        let file = file.unwrap();
+        if file.file_type().unwrap().is_dir() {
+            for file in std::fs::read_dir(file.path()).unwrap() {
+                let path = file.unwrap().path();
+                let path = path.strip_prefix(&current_dir).unwrap_or(&path);
+                println!(
+                    "- {}/\x1b[34m{}\x1b[0m",
+                    path.parent().unwrap().display(),
+                    path.file_name().unwrap().display(),
+                );
+            }
+        } else {
+            let path = file.path();
+            println!(
+                "- {} ",
+                path.strip_prefix(&current_dir).unwrap_or(&path).display()
+            );
+        }
+    }
+}

crates/zeta_cli/src/predict.rs 🔗

@@ -1,20 +1,15 @@
 use crate::PromptFormat;
 use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
 use crate::headless::ZetaCliAppState;
-use crate::paths::{
-    CACHE_DIR, LOGS_DIR, LOGS_PREDICTION_PROMPT, LOGS_PREDICTION_RESPONSE, LOGS_SEARCH_PROMPT,
-    LOGS_SEARCH_QUERIES,
-};
+use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
 use ::serde::Serialize;
-use anyhow::{Result, anyhow};
-use clap::Args;
-use collections::HashMap;
-use gpui::http_client::Url;
-use language::{Anchor, Buffer, Point};
-// use cloud_llm_client::predict_edits_v3::PromptFormat;
+use anyhow::{Context, Result, anyhow};
+use clap::{Args, ValueEnum};
 use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use collections::HashMap;
 use futures::StreamExt as _;
 use gpui::{AppContext, AsyncApp, Entity};
+use language::{Anchor, Buffer, Point};
 use project::Project;
 use serde::Deserialize;
 use std::cell::Cell;
@@ -25,7 +20,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::time::{Duration, Instant};
-use zeta2::LlmResponseCache;
+use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey};
 
 #[derive(Debug, Args)]
 pub struct PredictArguments {
@@ -36,8 +31,31 @@ pub struct PredictArguments {
     #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
     format: PredictionsOutputFormat,
     example_path: PathBuf,
-    #[clap(long)]
-    skip_cache: bool,
+    #[clap(long, value_enum, default_value_t = CacheMode::default())]
+    cache: CacheMode,
+}
+
+#[derive(Debug, ValueEnum, Default, Clone, Copy)]
+pub enum CacheMode {
+    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
+    #[default]
+    #[value(alias = "request")]
+    Requests,
+    /// Ignore existing cache entries for both LLM and search.
+    Skip,
+    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
+    /// Useful for reproducing results and fixing bugs outside of search queries
+    Force,
+}
+
+impl CacheMode {
+    fn use_cached_llm_responses(&self) -> bool {
+        matches!(self, CacheMode::Requests | CacheMode::Force)
+    }
+
+    fn use_cached_search_results(&self) -> bool {
+        matches!(self, CacheMode::Force)
+    }
 }
 
 #[derive(clap::ValueEnum, Debug, Clone)]
@@ -55,9 +73,9 @@ pub async fn run_zeta2_predict(
     let example = NamedExample::load(args.example_path).unwrap();
     let result = zeta2_predict(
         example,
-        args.skip_cache,
         args.prompt_format,
         args.use_expected_context,
+        args.cache,
         &app_state,
         cx,
     )
@@ -65,14 +83,7 @@ pub async fn run_zeta2_predict(
     .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
 
-    println!("## Logs\n");
-    println!("Search prompt: {}", LOGS_SEARCH_PROMPT.display());
-    println!("Search queries: {}", LOGS_SEARCH_QUERIES.display());
-    println!("Prediction prompt: {}", LOGS_PREDICTION_PROMPT.display());
-    println!(
-        "Prediction response: {}",
-        LOGS_PREDICTION_RESPONSE.display()
-    );
+    print_run_data_dir();
 }
 
 thread_local! {
@@ -81,13 +92,12 @@ thread_local! {
 
 pub async fn zeta2_predict(
     example: NamedExample,
-    skip_cache: bool,
     prompt_format: PromptFormat,
     use_expected_context: bool,
+    cache_mode: CacheMode,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
 ) -> Result<PredictionDetails> {
-    fs::create_dir_all(&*LOGS_DIR)?;
     let worktree_path = example.setup_worktree().await?;
 
     if !AUTHENTICATED.get() {
@@ -126,8 +136,25 @@ pub async fn zeta2_predict(
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
+    let example_run_dir = RUN_DIR.join(&example.file_name());
+    fs::create_dir_all(&example_run_dir)?;
+    if LATEST_EXAMPLE_RUN_DIR.exists() {
+        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+    }
+
+    #[cfg(unix)]
+    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+        .context("creating latest link")?;
+
+    #[cfg(windows)]
+    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+        .context("creating latest link")?;
+
     zeta.update(cx, |zeta, _cx| {
-        zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+        zeta.with_eval_cache(Arc::new(RunCache {
+            example_run_dir: example_run_dir.clone(),
+            cache_mode,
+        }));
     })?;
 
     cx.subscribe(&buffer_store, {
@@ -159,12 +186,15 @@ pub async fn zeta2_predict(
                 match event {
                     zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
                         start_time = Some(info.timestamp);
-                        fs::write(&*LOGS_SEARCH_PROMPT, &info.search_prompt)?;
+                        fs::write(
+                            example_run_dir.join("search_prompt.md"),
+                            &info.search_prompt,
+                        )?;
                     }
                     zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
                         search_queries_generated_at = Some(info.timestamp);
                         fs::write(
-                            &*LOGS_SEARCH_QUERIES,
+                            example_run_dir.join("search_queries.json"),
                             serde_json::to_string_pretty(&info.search_queries).unwrap(),
                         )?;
                     }
@@ -176,7 +206,7 @@ pub async fn zeta2_predict(
                         let prediction_started_at = Instant::now();
                         start_time.get_or_insert(prediction_started_at);
                         fs::write(
-                            &*LOGS_PREDICTION_PROMPT,
+                            example_run_dir.join("prediction_prompt.md"),
                             &request.local_prompt.unwrap_or_default(),
                         )?;
 
@@ -210,7 +240,7 @@ pub async fn zeta2_predict(
                         let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
                         let response = zeta2::text_from_response(response).unwrap_or_default();
                         let prediction_finished_at = Instant::now();
-                        fs::write(&*LOGS_PREDICTION_RESPONSE, &response)?;
+                        fs::write(example_run_dir.join("prediction_response.md"), &response)?;
 
                         let mut result = result.lock().unwrap();
 
@@ -328,48 +358,69 @@ async fn resolve_context_entry(
     Ok((buffer, ranges))
 }
 
-struct Cache {
-    skip_cache: bool,
+struct RunCache {
+    cache_mode: CacheMode,
+    example_run_dir: PathBuf,
 }
 
-impl Cache {
-    fn path(key: u64) -> PathBuf {
-        CACHE_DIR.join(format!("{key:x}.json"))
+impl RunCache {
+    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
     }
-}
 
-impl LlmResponseCache for Cache {
-    fn get_key(&self, url: &Url, body: &str) -> u64 {
-        use collections::FxHasher;
-        use std::hash::{Hash, Hasher};
+    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
+    }
 
-        let mut hasher = FxHasher::default();
-        url.hash(&mut hasher);
-        body.hash(&mut hasher);
-        hasher.finish()
+    fn link_to_run(&self, key: &EvalCacheKey) {
+        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
+        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+
+        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
+        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
     }
+}
+
+impl EvalCache for RunCache {
+    fn read(&self, key: EvalCacheKey) -> Option<String> {
+        let path = RunCache::output_cache_path(&key);
 
-    fn read_response(&self, key: u64) -> Option<String> {
-        let path = Cache::path(key);
         if path.exists() {
-            if self.skip_cache {
-                log::info!("Skipping existing cached LLM response: {}", path.display());
-                None
-            } else {
-                log::info!("Using LLM response from cache: {}", path.display());
+            let use_cache = match key.0 {
+                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
+                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
+                    self.cache_mode.use_cached_llm_responses()
+                }
+            };
+            if use_cache {
+                log::info!("Using cache entry: {}", path.display());
+                self.link_to_run(&key);
                 Some(fs::read_to_string(path).unwrap())
+            } else {
+                log::info!("Skipping cached entry: {}", path.display());
+                None
             }
+        } else if matches!(self.cache_mode, CacheMode::Force) {
+            panic!(
+                "No cached entry found for {:?}. Run without `--cache force` at least once.",
+                key.0
+            );
         } else {
             None
         }
     }
 
-    fn write_response(&self, key: u64, value: &str) {
+    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
         fs::create_dir_all(&*CACHE_DIR).unwrap();
 
-        let path = Cache::path(key);
-        log::info!("Writing LLM response to cache: {}", path.display());
-        fs::write(path, value).unwrap();
+        let input_path = RunCache::input_cache_path(&key);
+        fs::write(&input_path, input).unwrap();
+
+        let output_path = RunCache::output_cache_path(&key);
+        log::info!("Writing cache entry: {}", output_path.display());
+        fs::write(&output_path, output).unwrap();
+
+        self.link_to_run(&key);
     }
 }