Detailed changes
@@ -44,7 +44,7 @@ pub struct SearchToolInput {
}
/// Search for relevant code by path, syntax hierarchy, and content.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct SearchToolQuery {
/// 1. A glob pattern to match file paths in the codebase to search in.
pub glob: String,
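
The `Hash` derive exists so that a batch of queries can be folded into an eval-cache key. A minimal sketch of the intended use, mirroring the `run_retrieval_searches` hunk further down (all names are from this PR):

```rust
use std::hash::{Hash, Hasher};

// Sketch: hash a query batch into the search cache key. This requires
// the new `Hash` derive on `SearchToolQuery`.
fn search_queries_hash(queries: &[SearchToolQuery]) -> u64 {
    let mut hasher = collections::FxHasher::default();
    queries.hash(&mut hasher);
    hasher.finish()
}
```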
@@ -12,7 +12,7 @@ workspace = true
path = "src/zeta2.rs"
[features]
-llm-response-cache = []
+eval-support = []
[dependencies]
anyhow.workspace = true
@@ -1,5 +1,3 @@
-use std::ops::Range;
-
use anyhow::Result;
use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use collections::HashMap;
@@ -14,17 +12,76 @@ use project::{
search::{SearchQuery, SearchResult},
};
use smol::channel;
+use std::ops::Range;
use util::{
ResultExt as _,
paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;
+#[cfg(feature = "eval-support")]
+type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
+
pub async fn run_retrieval_searches(
- project: Entity<Project>,
queries: Vec<SearchToolQuery>,
+ project: Entity<Project>,
+ #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
+ #[cfg(feature = "eval-support")]
+ let cache = if let Some(eval_cache) = eval_cache {
+ use crate::EvalCacheEntryKind;
+ use anyhow::Context;
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
+
+ let mut hasher = FxHasher::default();
+ project.read_with(cx, |project, cx| {
+ let mut worktrees = project.worktrees(cx);
+ let Some(worktree) = worktrees.next() else {
+ panic!("Expected a single worktree in eval project. Found none.");
+ };
+ assert!(
+ worktrees.next().is_none(),
+ "Expected a single worktree in eval project. Found more than one."
+ );
+ worktree.read(cx).abs_path().hash(&mut hasher);
+ })?;
+
+ queries.hash(&mut hasher);
+ let key = (EvalCacheEntryKind::Search, hasher.finish());
+
+ if let Some(cached_results) = eval_cache.read(key) {
+ let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
+ .context("Failed to deserialize cached search results")?;
+ let mut results = HashMap::default();
+
+ for (path, ranges) in file_results {
+ let buffer = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path, cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })?
+ .await?;
+ let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+ let mut ranges = ranges
+ .into_iter()
+ .map(|range| {
+ snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
+ })
+ .collect();
+ merge_anchor_ranges(&mut ranges, &snapshot);
+ results.insert(buffer, ranges);
+ }
+
+ return Ok(results);
+ }
+
+ Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
+ } else {
+ None
+ };
+
let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
let global_settings = WorktreeSettings::get_global(cx);
let exclude_patterns = global_settings
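
For reference, `CachedSearchResults` round-trips through pretty-printed JSON; a sketch with hypothetical paths and offsets. serde renders `Range<usize>` as an object with `start` and `end` fields, and the `BTreeMap` keeps paths sorted so the serialized output is byte-stable across runs:

```rust
// Hypothetical cache entry for one file with two matched byte ranges.
let cached: CachedSearchResults =
    [(std::path::PathBuf::from("src/lib.rs"), vec![10..42, 100..180])]
        .into_iter()
        .collect();
let json = serde_json::to_string_pretty(&cached).unwrap();
// json: { "src/lib.rs": [{ "start": 10, "end": 42 }, { "start": 100, "end": 180 }] }
let roundtrip: CachedSearchResults = serde_json::from_str(&json).unwrap();
assert_eq!(cached, roundtrip);
```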
@@ -58,6 +115,8 @@ pub async fn run_retrieval_searches(
}
drop(results_tx);
+ #[cfg(feature = "eval-support")]
+ let cache = cache.clone();
cx.background_spawn(async move {
let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
let mut snapshots = HashMap::default();
@@ -79,6 +138,29 @@ pub async fn run_retrieval_searches(
}
}
+ #[cfg(feature = "eval-support")]
+ if let Some((cache, queries, key)) = cache {
+ let cached_results: CachedSearchResults = results
+ .iter()
+ .filter_map(|(buffer, ranges)| {
+ let snapshot = snapshots.get(&buffer.entity_id())?;
+ let path = snapshot.file().map(|f| f.path());
+ let mut ranges = ranges
+ .iter()
+ .map(|range| range.to_offset(&snapshot))
+ .collect::<Vec<_>>();
+ ranges.sort_unstable_by_key(|range| (range.start, range.end));
+
+ Some((path?.as_std_path().to_path_buf(), ranges))
+ })
+ .collect();
+ cache.write(
+ key,
+ &queries,
+ &serde_json::to_string_pretty(&cached_results)?,
+ );
+ }
+
for (buffer, ranges) in results.iter_mut() {
if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
merge_anchor_ranges(ranges, snapshot);
@@ -489,9 +571,10 @@ mod tests {
expected_output: &str,
cx: &mut TestAppContext,
) {
- let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
- .await
- .unwrap();
+ let results =
+ run_retrieval_searches(vec![query], project.clone(), None, &mut cx.to_async())
+ .await
+ .unwrap();
let mut results = results.into_iter().collect::<Vec<_>>();
results.sort_by_key(|results| {
@@ -105,21 +105,58 @@ fn resolve_new_text_old_text_in_buffer(
#[cfg(debug_assertions)]
fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option<String> {
let buffer_text = buffer.text();
- let mut cursor = 0;
let len = old_text.len();
+ if len == 0 || buffer_text.len() < len {
+ return None;
+ }
+
let mut min_score = usize::MAX;
let mut min_start = 0;
+ let old_text_bytes = old_text.as_bytes();
+ let old_alpha_count = old_text_bytes
+ .iter()
+ .filter(|&&b| b.is_ascii_alphanumeric())
+ .count();
+
+ let old_line_count = old_text.lines().count();
+
+ let mut cursor = 0;
+
while cursor + len <= buffer_text.len() {
let candidate = &buffer_text[cursor..cursor + len];
+ let candidate_bytes = candidate.as_bytes();
+
+ if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 {
+ cursor += 1;
+ continue;
+ }
+
+ let candidate_alpha_count = candidate_bytes
+ .iter()
+ .filter(|&&b| b.is_ascii_alphanumeric())
+ .count();
+
+ // If alphanumeric character count differs by more than 30%, skip
+ if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 {
+ cursor += 1;
+ continue;
+ }
+
let score = strsim::levenshtein(candidate, old_text);
if score < min_score {
min_score = score;
min_start = cursor;
+
+ if min_score <= len / 10 {
+ break;
+ }
}
+
cursor += 1;
}
+
if min_score != usize::MAX {
Some(buffer_text[min_start..min_start + len].to_string())
} else {
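
The rewritten scan adds three cheap rejections before paying for `strsim::levenshtein`: a line-count filter, a 30% alphanumeric-count filter, and an early exit once a near-match is found. A worked example of the arithmetic, with hypothetical counts:

```rust
// 30% filter: old_text has 50 alphanumeric bytes, the window has 70.
// |50 - 70| * 10 = 200 > 50 * 3 = 150, so this window is skipped
// without computing an edit distance.
let (old_alpha, candidate_alpha) = (50usize, 70usize);
assert!(usize::abs_diff(old_alpha, candidate_alpha) * 10 > old_alpha * 3);

// Early exit: for a 120-byte needle, any candidate within an edit
// distance of 12 (len / 10) ends the scan immediately.
let len = 120usize;
assert_eq!(len / 10, 12);
```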
@@ -132,15 +132,8 @@ pub struct Zeta {
options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
- #[cfg(feature = "llm-response-cache")]
- llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
-}
-
-#[cfg(feature = "llm-response-cache")]
-pub trait LlmResponseCache: Send + Sync {
- fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
- fn read_response(&self, key: u64) -> Option<String>;
- fn write_response(&self, key: u64, value: &str);
+ #[cfg(feature = "eval-support")]
+ eval_cache: Option<Arc<dyn EvalCache>>,
}
#[derive(Debug, Clone, PartialEq)]
@@ -369,14 +362,14 @@ impl Zeta {
),
update_required: false,
debug_tx: None,
- #[cfg(feature = "llm-response-cache")]
- llm_response_cache: None,
+ #[cfg(feature = "eval-support")]
+ eval_cache: None,
}
}
- #[cfg(feature = "llm-response-cache")]
- pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
- self.llm_response_cache = Some(cache);
+ #[cfg(feature = "eval-support")]
+ pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+ self.eval_cache = Some(cache);
}
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
@@ -736,9 +729,19 @@ impl Zeta {
// TODO data collection
let can_collect_data = cx.is_staff();
- let mut included_files = project_state
+ let empty_context_files = HashMap::default();
+ let context_files = project_state
.and_then(|project_state| project_state.context.as_ref())
- .unwrap_or(&HashMap::default())
+ .unwrap_or(&empty_context_files);
+
+ #[cfg(feature = "eval-support")]
+ let parsed_fut = futures::future::join_all(
+ context_files
+ .keys()
+ .map(|buffer| buffer.read(cx).parsing_idle()),
+ );
+
+ let mut included_files = context_files
.iter()
.filter_map(|(buffer_entity, ranges)| {
let buffer = buffer_entity.read(cx);
@@ -751,12 +754,19 @@ impl Zeta {
})
.collect::<Vec<_>>();
- #[cfg(feature = "llm-response-cache")]
- let llm_response_cache = self.llm_response_cache.clone();
+ included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
+ (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
+ });
+
+ #[cfg(feature = "eval-support")]
+ let eval_cache = self.eval_cache.clone();
let request_task = cx.background_spawn({
let active_buffer = active_buffer.clone();
async move {
+ #[cfg(feature = "eval-support")]
+ parsed_fut.await;
+
let index_state = if let Some(index_state) = index_state {
Some(index_state.lock_owned().await)
} else {
@@ -819,17 +829,17 @@ impl Zeta {
let included_files = included_files
.iter()
- .map(|(_, buffer, path, ranges)| {
+ .map(|(_, snapshot, path, ranges)| {
let excerpts = merge_excerpts(
- &buffer,
+ &snapshot,
ranges.iter().map(|range| {
- let point_range = range.to_point(&buffer);
+ let point_range = range.to_point(&snapshot);
Line(point_range.start.row)..Line(point_range.end.row)
}),
);
predict_edits_v3::IncludedFile {
path: path.clone(),
- max_row: Line(buffer.max_point().row),
+ max_row: Line(snapshot.max_point().row),
excerpts,
}
})
@@ -948,8 +958,10 @@ impl Zeta {
client,
llm_token,
app_version,
- #[cfg(feature = "llm-response-cache")]
- llm_response_cache,
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Prediction,
)
.await;
let request_time = chrono::Utc::now() - before_request;
@@ -1049,9 +1061,8 @@ impl Zeta {
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: SemanticVersion,
- #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
- Arc<dyn LlmResponseCache>,
- >,
+ #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+ #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
http_client::Url::parse(&predict_edits_url)?
@@ -1061,16 +1072,23 @@ impl Zeta {
.build_zed_llm_url("/predict_edits/raw", &[])?
};
- #[cfg(feature = "llm-response-cache")]
- let cache_key = if let Some(cache) = llm_response_cache {
- let request_json = serde_json::to_string(&request)?;
- let key = cache.get_key(&url, &request_json);
+ #[cfg(feature = "eval-support")]
+ let cache_key = if let Some(cache) = eval_cache {
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
+
+ let mut hasher = FxHasher::default();
+ url.hash(&mut hasher);
+ let request_str = serde_json::to_string_pretty(&request)?;
+ request_str.hash(&mut hasher);
+ let hash = hasher.finish();
- if let Some(response_str) = cache.read_response(key) {
+ let key = (eval_cache_kind, hash);
+ if let Some(response_str) = cache.read(key) {
return Ok((serde_json::from_str(&response_str)?, None));
}
- Some((cache, key))
+ Some((cache, request_str, key))
} else {
None
};
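
Factored into a helper purely for illustration, the key derivation is an `FxHash` over the endpoint plus the pretty-printed request body (types as used in this PR):

```rust
use std::hash::{Hash, Hasher};

// Sketch: one hash covers both the endpoint and the exact request body,
// so any change to either produces a fresh cache entry.
fn llm_request_hash(url: &gpui::http_client::Url, request_json: &str) -> u64 {
    let mut hasher = collections::FxHasher::default();
    url.hash(&mut hasher);
    request_json.hash(&mut hasher);
    hasher.finish()
}
```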
@@ -1088,9 +1106,9 @@ impl Zeta {
)
.await?;
- #[cfg(feature = "llm-response-cache")]
- if let Some((cache, key)) = cache_key {
- cache.write_response(key, &serde_json::to_string(&response)?);
+ #[cfg(feature = "eval-support")]
+ if let Some((cache, request, key)) = cache_key {
+ cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
}
Ok((response, usage))
@@ -1361,8 +1379,8 @@ impl Zeta {
reasoning_effort: None,
};
- #[cfg(feature = "llm-response-cache")]
- let llm_response_cache = self.llm_response_cache.clone();
+ #[cfg(feature = "eval-support")]
+ let eval_cache = self.eval_cache.clone();
cx.spawn(async move |this, cx| {
log::trace!("Sending search planning request");
@@ -1371,8 +1389,10 @@ impl Zeta {
client,
llm_token,
app_version,
- #[cfg(feature = "llm-response-cache")]
- llm_response_cache,
+ #[cfg(feature = "eval-support")]
+ eval_cache.clone(),
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Context,
)
.await;
let mut response = Self::handle_api_response(&this, response, cx)?;
@@ -1421,8 +1441,14 @@ impl Zeta {
log::trace!("Running retrieval search: {queries:#?}");
- let related_excerpts_result =
- retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
+ let related_excerpts_result = retrieval_search::run_retrieval_searches(
+ queries,
+ project.clone(),
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ cx,
+ )
+ .await;
log::trace!("Search queries executed");
@@ -1772,6 +1798,34 @@ fn add_signature(
Some(signature_index)
}
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
+
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+ Context,
+ Search,
+ Prediction,
+}
+
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ EvalCacheEntryKind::Search => write!(f, "search"),
+ EvalCacheEntryKind::Context => write!(f, "context"),
+ EvalCacheEntryKind::Prediction => write!(f, "prediction"),
+ }
+ }
+}
+
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+ fn read(&self, key: EvalCacheKey) -> Option<String>;
+ fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
+
#[cfg(test)]
mod tests {
use std::{path::Path, sync::Arc};
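
The trait is object-safe, so eval harnesses can plug in arbitrary storage. A hypothetical in-memory implementation, not part of this PR; note that `EvalCacheEntryKind` derives `PartialEq` but not `Eq`/`Hash`, so a linear scan stands in for a map here:

```rust
use std::sync::Mutex;

// Hypothetical stub, e.g. for tests. The real implementation in
// zeta_cli (`RunCache`, below) persists entries to disk instead.
#[derive(Default)]
struct MemoryEvalCache {
    entries: Mutex<Vec<(EvalCacheKey, String)>>,
}

impl EvalCache for MemoryEvalCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let entries = self.entries.lock().unwrap();
        entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
    }

    fn write(&self, key: EvalCacheKey, _input: &str, output: &str) {
        self.entries.lock().unwrap().push((key, output.to_string()));
    }
}
```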
@@ -54,7 +54,7 @@ toml.workspace = true
util.workspace = true
watch.workspace = true
zeta.workspace = true
-zeta2 = { workspace = true, features = ["llm-response-cache"] }
+zeta2 = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
[dev-dependencies]
@@ -14,18 +14,19 @@ use crate::{
PromptFormat,
example::{Example, NamedExample},
headless::ZetaCliAppState,
- predict::{PredictionDetails, zeta2_predict},
+ paths::print_run_data_dir,
+ predict::{CacheMode, PredictionDetails, zeta2_predict},
};
#[derive(Debug, Args)]
pub struct EvaluateArguments {
example_paths: Vec<PathBuf>,
- #[clap(long)]
- skip_cache: bool,
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
prompt_format: PromptFormat,
#[arg(long)]
use_expected_context: bool,
+ #[clap(long, value_enum, default_value_t = CacheMode::default())]
+ cache: CacheMode,
}
pub async fn run_evaluate(
@@ -39,43 +40,49 @@ pub async fn run_evaluate(
cx.spawn(async move |cx| {
run_evaluate_one(
&path,
- args.skip_cache,
args.prompt_format,
args.use_expected_context,
+ args.cache,
app_state.clone(),
cx,
)
.await
})
});
- let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
+ let all_results = futures::future::try_join_all(all_tasks).await;
+
+ if let Ok(all_results) = &all_results {
+ let aggregated_result = EvaluationResult {
+ context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
+ edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+ };
+
+ if example_len > 1 {
+ println!("\n{}", "-".repeat(80));
+ println!("\n## TOTAL SCORES");
+ println!("{}", aggregated_result.to_markdown());
+ }
+ }
- let aggregated_result = EvaluationResult {
- context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
- edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
- };
+ print_run_data_dir();
- if example_len > 1 {
- println!("\n{}", "-".repeat(80));
- println!("# TOTAL SCORES:");
- println!("{}", aggregated_result.to_markdown());
- }
+ all_results.unwrap();
}
pub async fn run_evaluate_one(
example_path: &Path,
- skip_cache: bool,
prompt_format: PromptFormat,
use_expected_context: bool,
+ cache_mode: CacheMode,
app_state: Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
let example = NamedExample::load(&example_path).unwrap();
let predictions = zeta2_predict(
example.clone(),
- skip_cache,
prompt_format,
use_expected_context,
+ cache_mode,
&app_state,
cx,
)
@@ -398,7 +398,7 @@ impl NamedExample {
Ok(worktree_path)
}
- fn file_name(&self) -> String {
+ pub fn file_name(&self) -> String {
self.name
.chars()
.map(|c| {
@@ -54,6 +54,7 @@ enum Command {
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
output_format: ExampleFormat,
},
+ Clean,
}
#[derive(Subcommand, Debug)]
@@ -470,6 +471,7 @@ fn main() {
let example = NamedExample::load(path).unwrap();
example.write(output_format, io::stdout()).unwrap();
}
+ Command::Clean => std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap(),
};
let _ = cx.update(|cx| cx.quit());
@@ -1,16 +1,40 @@
use std::{env, path::PathBuf, sync::LazyLock};
-static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
-pub static CACHE_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| TARGET_DIR.join("zeta-llm-response-cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
-pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
-pub static LOGS_SEARCH_PROMPT: LazyLock<PathBuf> =
- LazyLock::new(|| LOGS_DIR.join("search_prompt.md"));
-pub static LOGS_SEARCH_QUERIES: LazyLock<PathBuf> =
- LazyLock::new(|| LOGS_DIR.join("search_queries.json"));
-pub static LOGS_PREDICTION_PROMPT: LazyLock<PathBuf> =
- LazyLock::new(|| LOGS_DIR.join("prediction_prompt.md"));
-pub static LOGS_PREDICTION_RESPONSE: LazyLock<PathBuf> =
- LazyLock::new(|| LOGS_DIR.join("prediction_response.md"));
+pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
+pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+ TARGET_ZETA_DIR
+ .join("runs")
+ .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
+});
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
+
+pub fn print_run_data_dir() {
+ println!("\n## Run Data\n");
+
+ let current_dir = std::env::current_dir().unwrap();
+ for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
+ let file = file.unwrap();
+ if file.file_type().unwrap().is_dir() {
+ for file in std::fs::read_dir(file.path()).unwrap() {
+ let path = file.unwrap().path();
+                let path = path.strip_prefix(&current_dir).unwrap_or(&path);
+ println!(
+ "- {}/\x1b[34m{}\x1b[0m",
+ path.parent().unwrap().display(),
+ path.file_name().unwrap().display(),
+ );
+ }
+ } else {
+ let path = file.path();
+ println!(
+ "- {} ",
+                path.strip_prefix(&current_dir).unwrap_or(&path).display()
+ );
+ }
+ }
+}
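
Assuming a run on 1 May 2024 at 12:30:00 against an example named `my_example` (both hypothetical), the consolidated layout under `target/zeta/` comes out roughly as:

```
target/zeta/
├── cache/                   # content-addressed entries, shared across runs
├── repos/                   # cloned example repositories
├── worktrees/               # per-example worktrees
├── runs/
│   └── 01-05-24-12_30_00/   # one directory per invocation (%d-%m-%y-%H_%M_%S)
│       └── my_example/      # prompts, responses, hard-linked cache entries
└── latest -> runs/01-05-24-12_30_00/my_example
```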
@@ -1,20 +1,15 @@
use crate::PromptFormat;
use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
-use crate::paths::{
- CACHE_DIR, LOGS_DIR, LOGS_PREDICTION_PROMPT, LOGS_PREDICTION_RESPONSE, LOGS_SEARCH_PROMPT,
- LOGS_SEARCH_QUERIES,
-};
+use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use ::serde::Serialize;
-use anyhow::{Result, anyhow};
-use clap::Args;
-use collections::HashMap;
-use gpui::http_client::Url;
-use language::{Anchor, Buffer, Point};
-// use cloud_llm_client::predict_edits_v3::PromptFormat;
+use anyhow::{Context, Result, anyhow};
+use clap::{Args, ValueEnum};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
+use language::{Anchor, Buffer, Point};
use project::Project;
use serde::Deserialize;
use std::cell::Cell;
@@ -25,7 +20,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
-use zeta2::LlmResponseCache;
+use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey};
#[derive(Debug, Args)]
pub struct PredictArguments {
@@ -36,8 +31,31 @@ pub struct PredictArguments {
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
- #[clap(long)]
- skip_cache: bool,
+ #[clap(long, value_enum, default_value_t = CacheMode::default())]
+ cache: CacheMode,
+}
+
+#[derive(Debug, ValueEnum, Default, Clone, Copy)]
+pub enum CacheMode {
+ /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
+ #[default]
+ #[value(alias = "request")]
+ Requests,
+ /// Ignore existing cache entries for both LLM and search.
+ Skip,
+ /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
+    /// Useful for reproducing results and fixing bugs outside of search queries.
+ Force,
+}
+
+impl CacheMode {
+ fn use_cached_llm_responses(&self) -> bool {
+ matches!(self, CacheMode::Requests | CacheMode::Force)
+ }
+
+ fn use_cached_search_results(&self) -> bool {
+ matches!(self, CacheMode::Force)
+ }
}
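
The two helpers reduce to this matrix, written as assertions for clarity (a sketch):

```rust
// Requests (default): reuse LLM responses, always re-run searches.
assert!(CacheMode::Requests.use_cached_llm_responses());
assert!(!CacheMode::Requests.use_cached_search_results());

// Skip: ignore every cached entry and hit the model again.
assert!(!CacheMode::Skip.use_cached_llm_responses());
assert!(!CacheMode::Skip.use_cached_search_results());

// Force: fully deterministic; panics later if an entry is missing.
assert!(CacheMode::Force.use_cached_llm_responses());
assert!(CacheMode::Force.use_cached_search_results());
```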
#[derive(clap::ValueEnum, Debug, Clone)]
@@ -55,9 +73,9 @@ pub async fn run_zeta2_predict(
let example = NamedExample::load(args.example_path).unwrap();
let result = zeta2_predict(
example,
- args.skip_cache,
args.prompt_format,
args.use_expected_context,
+ args.cache,
&app_state,
cx,
)
@@ -65,14 +83,7 @@ pub async fn run_zeta2_predict(
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
- println!("## Logs\n");
- println!("Search prompt: {}", LOGS_SEARCH_PROMPT.display());
- println!("Search queries: {}", LOGS_SEARCH_QUERIES.display());
- println!("Prediction prompt: {}", LOGS_PREDICTION_PROMPT.display());
- println!(
- "Prediction response: {}",
- LOGS_PREDICTION_RESPONSE.display()
- );
+ print_run_data_dir();
}
thread_local! {
@@ -81,13 +92,12 @@ thread_local! {
pub async fn zeta2_predict(
example: NamedExample,
- skip_cache: bool,
prompt_format: PromptFormat,
use_expected_context: bool,
+ cache_mode: CacheMode,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
- fs::create_dir_all(&*LOGS_DIR)?;
let worktree_path = example.setup_worktree().await?;
if !AUTHENTICATED.get() {
@@ -126,8 +136,25 @@ pub async fn zeta2_predict(
let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+ let example_run_dir = RUN_DIR.join(&example.file_name());
+ fs::create_dir_all(&example_run_dir)?;
+ if LATEST_EXAMPLE_RUN_DIR.exists() {
+ fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+ }
+
+ #[cfg(unix)]
+ std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+ .context("creating latest link")?;
+
+ #[cfg(windows)]
+ std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
+ .context("creating latest link")?;
+
zeta.update(cx, |zeta, _cx| {
- zeta.with_llm_response_cache(Arc::new(Cache { skip_cache }));
+ zeta.with_eval_cache(Arc::new(RunCache {
+ example_run_dir: example_run_dir.clone(),
+ cache_mode,
+ }));
})?;
cx.subscribe(&buffer_store, {
@@ -159,12 +186,15 @@ pub async fn zeta2_predict(
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
- fs::write(&*LOGS_SEARCH_PROMPT, &info.search_prompt)?;
+ fs::write(
+ example_run_dir.join("search_prompt.md"),
+ &info.search_prompt,
+ )?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
fs::write(
- &*LOGS_SEARCH_QUERIES,
+ example_run_dir.join("search_queries.json"),
serde_json::to_string_pretty(&info.search_queries).unwrap(),
)?;
}
@@ -176,7 +206,7 @@ pub async fn zeta2_predict(
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
fs::write(
- &*LOGS_PREDICTION_PROMPT,
+ example_run_dir.join("prediction_prompt.md"),
&request.local_prompt.unwrap_or_default(),
)?;
@@ -210,7 +240,7 @@ pub async fn zeta2_predict(
let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response = zeta2::text_from_response(response).unwrap_or_default();
let prediction_finished_at = Instant::now();
- fs::write(&*LOGS_PREDICTION_RESPONSE, &response)?;
+ fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
@@ -328,48 +358,69 @@ async fn resolve_context_entry(
Ok((buffer, ranges))
}
-struct Cache {
- skip_cache: bool,
+struct RunCache {
+ cache_mode: CacheMode,
+ example_run_dir: PathBuf,
}
-impl Cache {
- fn path(key: u64) -> PathBuf {
- CACHE_DIR.join(format!("{key:x}.json"))
+impl RunCache {
+ fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+ CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
}
-}
-impl LlmResponseCache for Cache {
- fn get_key(&self, url: &Url, body: &str) -> u64 {
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
+ fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
+ CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
+ }
- let mut hasher = FxHasher::default();
- url.hash(&mut hasher);
- body.hash(&mut hasher);
- hasher.finish()
+ fn link_to_run(&self, key: &EvalCacheKey) {
+ let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
+ fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+
+ let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
+ fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
}
+}
+
+impl EvalCache for RunCache {
+ fn read(&self, key: EvalCacheKey) -> Option<String> {
+ let path = RunCache::output_cache_path(&key);
- fn read_response(&self, key: u64) -> Option<String> {
- let path = Cache::path(key);
if path.exists() {
- if self.skip_cache {
- log::info!("Skipping existing cached LLM response: {}", path.display());
- None
- } else {
- log::info!("Using LLM response from cache: {}", path.display());
+ let use_cache = match key.0 {
+ EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
+ EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
+ self.cache_mode.use_cached_llm_responses()
+ }
+ };
+ if use_cache {
+ log::info!("Using cache entry: {}", path.display());
+ self.link_to_run(&key);
Some(fs::read_to_string(path).unwrap())
+ } else {
+ log::info!("Skipping cached entry: {}", path.display());
+ None
}
+ } else if matches!(self.cache_mode, CacheMode::Force) {
+ panic!(
+ "No cached entry found for {:?}. Run without `--cache force` at least once.",
+ key.0
+ );
} else {
None
}
}
- fn write_response(&self, key: u64, value: &str) {
+ fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
fs::create_dir_all(&*CACHE_DIR).unwrap();
- let path = Cache::path(key);
- log::info!("Writing LLM response to cache: {}", path.display());
- fs::write(path, value).unwrap();
+ let input_path = RunCache::input_cache_path(&key);
+ fs::write(&input_path, input).unwrap();
+
+ let output_path = RunCache::output_cache_path(&key);
+ log::info!("Writing cache entry: {}", output_path.display());
+ fs::write(&output_path, output).unwrap();
+
+ self.link_to_run(&key);
}
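
Concretely, each entry produces an input/output pair in the shared cache, hard-linked into the per-example run directory. A sketch for a prediction entry with a hypothetical hash of `0x1a2b`:

```rust
let key: EvalCacheKey = (EvalCacheEntryKind::Prediction, 0x1a2b);
assert_eq!(
    RunCache::output_cache_path(&key),
    CACHE_DIR.join("prediction_out_1a2b.json")
);
assert_eq!(
    RunCache::input_cache_path(&key),
    CACHE_DIR.join("prediction_in_1a2b.json")
);
// link_to_run(&key) then hard-links both files into the example's run
// directory as `prediction_out.json` and `prediction_in.json`.
```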
}