From 91d631c22920518dfe025a40dbcbd1bbb867784c Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Tue, 4 Nov 2025 19:36:50 +0200 Subject: [PATCH] Evaluate zeta2 context retrieval and edit predictions (#41921) This PR implements the `zeta-cli eval` command. It will: - Run the edit prediction model if there are no cached results - Compute precision/recall/F1 for context retrieval at the line level: every retrieved line of context is counted as a true positive (correct retrieval), false positive (retrieved something that was not expected), or false negative (didn't retrieve an expected line) - Compute similar metrics for edit predictions - Pretty-print results, highlighting the difference between actual and expected when printing to tty Other changes: - `zeta-cli predict` accepts a `--format` argument with options `md`, `json`, `diff` - Code restructure Release Notes: - N/A --------- Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Co-authored-by: Agus Zubiaga --- crates/cloud_llm_client/src/udiff.rs | 36 ++- crates/zeta_cli/src/evaluate.rs | 338 +++++++++++++++++++++++++++ crates/zeta_cli/src/example.rs | 48 +++- crates/zeta_cli/src/main.rs | 297 ++++------------------- crates/zeta_cli/src/predict.rs | 287 +++++++++++++++++++++++ 5 files changed, 742 insertions(+), 264 deletions(-) create mode 100644 crates/zeta_cli/src/evaluate.rs create mode 100644 crates/zeta_cli/src/predict.rs diff --git a/crates/cloud_llm_client/src/udiff.rs b/crates/cloud_llm_client/src/udiff.rs index 444452e6b7350de1680d51b5b9a34eab69685fa3..c5972fc139dd105c9d4a0f4d5917950752278cb7 100644 --- a/crates/cloud_llm_client/src/udiff.rs +++ b/crates/cloud_llm_client/src/udiff.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::{borrow::Cow, fmt::Display}; #[derive(Debug, PartialEq)] pub enum DiffLine<'a> { @@ -8,7 +8,7 @@ pub enum DiffLine<'a> { Context(&'a str), Deletion(&'a str), Addition(&'a str), - Garbage, + Garbage(&'a str), } #[derive(Debug, PartialEq)] @@ -21,7 +21,7 @@ pub struct HunkLocation { impl<'a> DiffLine<'a> { pub fn parse(line: &'a str) -> Self { - Self::try_parse(line).unwrap_or(Self::Garbage) + Self::try_parse(line).unwrap_or(Self::Garbage(line)) } fn try_parse(line: &'a str) -> Option { @@ -60,6 +60,30 @@ impl<'a> DiffLine<'a> { } } +impl<'a> Display for DiffLine<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DiffLine::OldPath { path } => write!(f, "--- {path}"), + DiffLine::NewPath { path } => write!(f, "+++ {path}"), + DiffLine::HunkHeader(Some(hunk_location)) => { + write!( + f, + "@@ -{},{} +{},{} @@", + hunk_location.start_line_old + 1, + hunk_location.count_old, + hunk_location.start_line_new + 1, + hunk_location.count_new + ) + } + DiffLine::HunkHeader(None) => write!(f, "@@ ... 
@@"), + DiffLine::Context(content) => write!(f, " {content}"), + DiffLine::Deletion(content) => write!(f, "-{content}"), + DiffLine::Addition(content) => write!(f, "+{content}"), + DiffLine::Garbage(line) => write!(f, "{line}"), + } + } +} + fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { if !header.contains(['"', '\\']) { let path = header.split_ascii_whitespace().next().unwrap_or(header); @@ -134,8 +158,8 @@ mod tests { pretty_assertions::assert_eq!( lines, &[ - DiffLine::Garbage, - DiffLine::Garbage, + DiffLine::Garbage("diff --git a/text.txt b/text.txt"), + DiffLine::Garbage("index 86c770d..a1fd855 100644"), DiffLine::OldPath { path: "file.txt".into() }, @@ -151,7 +175,7 @@ mod tests { DiffLine::Context("context"), DiffLine::Deletion("deleted"), DiffLine::Addition("inserted"), - DiffLine::Garbage, + DiffLine::Garbage("garbage"), DiffLine::Context(""), DiffLine::OldPath { path: "b/file.txt".into() diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs new file mode 100644 index 0000000000000000000000000000000000000000..6b5e2d0eecb8ca5e38ca233254aa1b0271448a11 --- /dev/null +++ b/crates/zeta_cli/src/evaluate.rs @@ -0,0 +1,338 @@ +use std::{ + fs, + io::IsTerminal, + path::{Path, PathBuf}, + sync::Arc, +}; + +use anyhow::Result; +use clap::Args; +use cloud_llm_client::udiff::DiffLine; +use collections::HashSet; +use gpui::AsyncApp; + +use crate::{ + example::{Example, NamedExample}, + headless::ZetaCliAppState, + predict::{PredictionDetails, zeta2_predict}, +}; + +#[derive(Debug, Args)] +pub struct EvaluateArguments { + example_paths: Vec, + #[clap(long)] + re_run: bool, +} + +pub async fn run_evaluate( + args: EvaluateArguments, + app_state: &Arc, + cx: &mut AsyncApp, +) { + let example_len = args.example_paths.len(); + let all_tasks = args.example_paths.into_iter().map(|path| { + let app_state = app_state.clone(); + cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await) + }); + let all_results = futures::future::try_join_all(all_tasks).await.unwrap(); + + let aggregated_result = EvaluationResult { + context: Scores::aggregate(all_results.iter().map(|r| &r.context)), + edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)), + }; + + if example_len > 1 { + println!("\n{}", "-".repeat(80)); + println!("# TOTAL SCORES:"); + println!("{}", aggregated_result.to_markdown()); + } +} + +pub async fn run_evaluate_one( + example_path: &Path, + re_run: bool, + app_state: Arc, + cx: &mut AsyncApp, +) -> Result { + let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default()) + .join("../../target/zeta-prediction-cache"); + let example = NamedExample::load(&example_path).unwrap(); + let example_cache_path = cache_dir.join(&example_path.file_name().unwrap()); + + let predictions = if !re_run && example_cache_path.exists() { + let file_contents = fs::read_to_string(&example_cache_path)?; + let as_json = serde_json::from_str::(&file_contents)?; + log::debug!( + "Loaded predictions from cache: {}", + example_cache_path.display() + ); + as_json + } else { + zeta2_predict(example.clone(), &app_state, cx) + .await + .unwrap() + }; + + if !example_cache_path.exists() { + fs::create_dir_all(&cache_dir).unwrap(); + fs::write( + example_cache_path, + serde_json::to_string(&predictions).unwrap(), + ) + .unwrap(); + } + + let evaluation_result = evaluate(&example.example, &predictions); + + println!("# {}\n", example.name); + println!( + "## Expected Context: 
\n\n```\n{}\n```\n\n", + compare_context(&example.example, &predictions) + ); + println!( + "## Expected edit prediction:\n\n```diff\n{}\n```\n", + compare_diffs(&example.example.expected_patch, &predictions.diff) + ); + println!( + "## Actual edit prediction:\n\n```diff\n{}\n```\n", + compare_diffs(&predictions.diff, &example.example.expected_patch) + ); + + println!("{}", evaluation_result.to_markdown()); + + anyhow::Ok(evaluation_result) +} + +#[derive(Debug, Default)] +pub struct EvaluationResult { + pub context: Scores, + pub edit_prediction: Scores, +} + +#[derive(Default, Debug)] +pub struct Scores { + pub precision: f64, + pub recall: f64, + pub f1_score: f64, + pub true_positives: usize, + pub false_positives: usize, + pub false_negatives: usize, +} + +impl Scores { + pub fn to_markdown(&self) -> String { + format!( + " +Precision : {:.4} +Recall : {:.4} +F1 Score : {:.4} +True Positives : {} +False Positives : {} +False Negatives : {}", + self.precision, + self.recall, + self.f1_score, + self.true_positives, + self.false_positives, + self.false_negatives + ) + } +} + +impl Scores { + pub fn aggregate<'a>(scores: impl Iterator) -> Scores { + let mut true_positives = 0; + let mut false_positives = 0; + let mut false_negatives = 0; + + for score in scores { + true_positives += score.true_positives; + false_positives += score.false_positives; + false_negatives += score.false_negatives; + } + + let precision = true_positives as f64 / (true_positives + false_positives) as f64; + let recall = true_positives as f64 / (true_positives + false_negatives) as f64; + let mut f1_score = 2.0 * precision * recall / (precision + recall); + if f1_score.is_nan() { + f1_score = 0.0; + } + + Scores { + precision, + recall, + f1_score, + true_positives, + false_positives, + false_negatives, + } + } +} + +impl EvaluationResult { + pub fn to_markdown(&self) -> String { + format!( + r#" +### Context Scores +{} + +### Edit Prediction Scores +{} +"#, + self.context.to_markdown(), + self.edit_prediction.to_markdown() + ) + } +} + +pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult { + let mut result = EvaluationResult::default(); + + let expected_context_lines = example + .expected_excerpts + .iter() + .flat_map(|excerpt| { + excerpt + .text + .lines() + .map(|line| format!("{}: {line}", excerpt.path.display())) + }) + .collect(); + let actual_context_lines = preds + .excerpts + .iter() + .flat_map(|excerpt| { + excerpt + .text + .lines() + .map(|line| format!("{}: {line}", excerpt.path.display())) + }) + .collect(); + + result.context = precision_recall(&expected_context_lines, &actual_context_lines); + + let expected_patch_lines = example + .expected_patch + .lines() + .map(DiffLine::parse) + .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) + .map(|line| line.to_string()) + .collect(); + + let actual_patch_lines = preds + .diff + .lines() + .map(DiffLine::parse) + .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) + .map(|line| line.to_string()) + .collect(); + + result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines); + + result +} + +fn precision_recall(expected: &HashSet, actual: &HashSet) -> Scores { + let true_positives = expected.intersection(actual).count(); + let false_positives = actual.difference(expected).count(); + let false_negatives = expected.difference(actual).count(); + + let precision = if true_positives + false_positives == 0 { + 0.0 + } else { + true_positives as f64 / 
(true_positives + false_positives) as f64 + }; + let recall = if true_positives + false_negatives == 0 { + 0.0 + } else { + true_positives as f64 / (true_positives + false_negatives) as f64 + }; + let f1_score = if precision + recall == 0.0 { + 0.0 + } else { + 2.0 * precision * recall / (precision + recall) + }; + + Scores { + precision, + recall, + f1_score, + true_positives, + false_positives, + false_negatives, + } +} + +/// Compare actual and expected context. +/// +/// Return expected context annotated with these markers: +/// +/// `✓ context line` -- line was correctly predicted +/// `✗ context line` -- line is missing from predictions +pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String { + let use_color = std::io::stdout().is_terminal(); + let green = if use_color { "\x1b[32m" } else { "" }; + let red = if use_color { "\x1b[31m" } else { "" }; + let reset = if use_color { "\x1b[0m" } else { "" }; + let expected: Vec<_> = example + .expected_excerpts + .iter() + .flat_map(|excerpt| { + excerpt + .text + .lines() + .map(|line| (excerpt.path.clone(), line)) + }) + .collect(); + let actual: HashSet<_> = preds + .excerpts + .iter() + .flat_map(|excerpt| { + excerpt + .text + .lines() + .map(|line| (excerpt.path.clone(), line)) + }) + .collect(); + + let annotated = expected + .iter() + .map(|(path, line)| { + if actual.contains(&(path.to_path_buf(), line)) { + format!("{green}✓ {line}{reset}") + } else { + format!("{red}✗ {line}{reset}") + } + }) + .collect::>(); + + annotated.join("\n") +} + +/// Return annotated `patch_a` so that: +/// Additions and deletions that are not present in `patch_b` will be highlighted in red. +/// Additions and deletions that are present in `patch_b` will be highlighted in green. +pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String { + let use_color = std::io::stdout().is_terminal(); + let green = if use_color { "\x1b[32m✓ " } else { "" }; + let red = if use_color { "\x1b[31m✗ " } else { "" }; + let neutral = if use_color { " " } else { "" }; + let reset = if use_color { "\x1b[0m" } else { "" }; + let lines_a = patch_a.lines().map(DiffLine::parse); + let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect(); + + let annotated = lines_a + .map(|line| match line { + DiffLine::Addition(_) | DiffLine::Deletion(_) => { + if lines_b.contains(&line) { + format!("{green}{line}{reset}") + } else { + format!("{red}{line}{reset}") + } + } + _ => format!("{neutral}{line}{reset}"), + }) + .collect::>(); + + annotated.join("\n") +} diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index e742241787cbc714deb6ab934f07bc01218dce10..6537068e84cc46f5ab72a0e1bd9e19445a3ec37e 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + cell::RefCell, env, fmt::{self, Display}, fs, @@ -7,12 +8,16 @@ use std::{ mem, ops::Range, path::{Path, PathBuf}, + sync::Arc, }; use anyhow::{Context as _, Result}; use clap::ValueEnum; -use collections::HashSet; -use futures::AsyncWriteExt as _; +use collections::{HashMap, HashSet}; +use futures::{ + AsyncWriteExt as _, + lock::{Mutex, OwnedMutexGuard}, +}; use gpui::{AsyncApp, Entity, http_client::Url}; use language::Buffer; use project::{Project, ProjectPath}; @@ -27,13 +32,13 @@ const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts"; const REPOSITORY_URL_FIELD: &str = "repository_url"; const REVISION_FIELD: &str = "revision"; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct NamedExample { pub 
name: String, pub example: Example, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Example { pub repository_url: String, pub revision: String, @@ -45,10 +50,13 @@ pub struct Example { pub expected_excerpts: Vec, } -#[derive(Debug, Serialize, Deserialize)] -pub struct ExpectedExcerpt { - path: PathBuf, - text: String, +pub type ExpectedExcerpt = Excerpt; +pub type ActualExcerpt = Excerpt; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Excerpt { + pub path: PathBuf, + pub text: String, } #[derive(ValueEnum, Debug, Clone)] @@ -171,6 +179,7 @@ impl NamedExample { } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { named.example.expected_patch = mem::take(&mut text); } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) { + // TODO: "…" should not be a part of the excerpt named.example.expected_excerpts.push(ExpectedExcerpt { path: block_info.into(), text: mem::take(&mut text), @@ -202,7 +211,6 @@ impl NamedExample { } } - #[allow(unused)] pub async fn setup_worktree(&self) -> Result { let (repo_owner, repo_name) = self.repo_name()?; let file_name = self.file_name(); @@ -213,6 +221,8 @@ impl NamedExample { fs::create_dir_all(&worktrees_dir)?; let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref()); + let repo_lock = lock_repo(&repo_dir).await; + if !repo_dir.is_dir() { fs::create_dir_all(&repo_dir)?; run_git(&repo_dir, &["init"]).await?; @@ -255,6 +265,7 @@ impl NamedExample { ) .await?; } + drop(repo_lock); // Apply the uncommitted diff for this example. if !self.example.uncommitted_diff.is_empty() { @@ -419,6 +430,23 @@ impl Display for NamedExample { } } +thread_local! { + static REPO_LOCKS: RefCell>>> = RefCell::new(HashMap::default()); +} + +#[must_use] +pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { + REPO_LOCKS + .with(|cell| { + cell.borrow_mut() + .entry(path.as_ref().to_path_buf()) + .or_default() + .clone() + }) + .lock_owned() + .await +} + #[must_use] pub async fn apply_diff( diff: &str, @@ -487,7 +515,7 @@ pub async fn apply_diff( }); } } - DiffLine::HunkHeader(_) | DiffLine::Garbage => {} + DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {} } let at_hunk_end = match diff_lines.peek() { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index f0d1cb3fd445d841c2f237c2f828c65c326836ea..43d5b899c8e1c3f3656d5752ffd226dc4b73656d 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -1,14 +1,17 @@ +mod evaluate; mod example; mod headless; +mod predict; mod source_location; mod syntax_retrieval_stats; mod util; +use crate::evaluate::{EvaluateArguments, run_evaluate}; use crate::example::{ExampleFormat, NamedExample}; +use crate::predict::{PredictArguments, run_zeta2_predict}; use crate::syntax_retrieval_stats::retrieval_stats; use ::serde::Serialize; use ::util::paths::PathStyle; -use ::util::rel_path::RelPath; use anyhow::{Context as _, Result, anyhow}; use clap::{Args, Parser, Subcommand}; use cloud_llm_client::predict_edits_v3::{self, Excerpt}; @@ -22,12 +25,12 @@ use futures::channel::mpsc; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point}; use language_model::LanguageModelRegistry; -use project::{Project, ProjectPath, Worktree}; +use project::{Project, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; -use std::io; -use std::time::{Duration, Instant}; -use std::{collections::HashSet, path::PathBuf, 
process::exit, str::FromStr, sync::Arc}; +use std::io::{self}; +use std::time::Duration; +use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; use crate::headless::ZetaCliAppState; @@ -82,9 +85,8 @@ enum Zeta2Command { #[command(subcommand)] command: Zeta2LlmCommand, }, - Predict { - example_path: PathBuf, - }, + Predict(PredictArguments), + Eval(EvaluateArguments), } #[derive(Subcommand, Debug)] @@ -327,208 +329,6 @@ async fn load_context( }) } -async fn zeta2_predict( - example: NamedExample, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result<()> { - let worktree_path = example.setup_worktree().await?; - - cx.update(|cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry - .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) - .unwrap() - .authenticate(cx) - }) - })? - .await?; - - app_state - .client - .sign_in_with_optional_connect(true, cx) - .await?; - - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(&worktree_path, true, cx) - })? - .await?; - worktree - .read_with(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? - .await; - - let _edited_buffers = example.apply_edit_history(&project, cx).await?; - - let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc(); - - let cursor_buffer = project - .update(cx, |project, cx| { - project.open_buffer( - ProjectPath { - worktree_id: worktree.read(cx).id(), - path: cursor_path, - }, - cx, - ) - })? 
- .await?; - - let cursor_offset_within_excerpt = example - .example - .cursor_position - .find(CURSOR_MARKER) - .ok_or_else(|| anyhow!("missing cursor marker"))?; - let mut cursor_excerpt = example.example.cursor_position.clone(); - cursor_excerpt.replace_range( - cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), - "", - ); - let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { - let text = buffer.text(); - - let mut matches = text.match_indices(&cursor_excerpt); - let Some((excerpt_offset, _)) = matches.next() else { - anyhow::bail!( - "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" - ); - }; - assert!(matches.next().is_none()); - - Ok(excerpt_offset) - })??; - - let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; - let cursor_anchor = - cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; - - let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; - - let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.register_buffer(&cursor_buffer, &project, cx); - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })?; - - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; - let mut context_retrieval_started_at = None; - let mut context_retrieval_finished_at = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; - let mut prediction_started_at = None; - let mut prediction_finished_at = None; - let mut excerpts_text = String::new(); - let mut prediction_task = None; - while let Some(event) = debug_rx.next().await { - match event { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - context_retrieval_started_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); - } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { - context_retrieval_finished_at = Some(info.timestamp); - - prediction_task = Some(zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) - })?); - } - zeta2::ZetaDebugInfo::EditPredicted(request) => { - prediction_started_at = Some(Instant::now()); - request.response_rx.await?.map_err(|err| anyhow!(err))?; - prediction_finished_at = Some(Instant::now()); - - for included_file in request.request.included_files { - let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; - write_codeblock( - &included_file.path, - included_file.excerpts.iter(), - if included_file.path == request.request.excerpt_path { - &insertions - } else { - &[] - }, - included_file.max_row, - false, - &mut excerpts_text, - ); - } - break; - } - _ => {} - } - } - - refresh_task.await.context("context retrieval failed")?; - let prediction = prediction_task.unwrap().await?.context("No prediction")?; - - println!("## Excerpts\n"); - println!("{excerpts_text}"); - - let old_text = prediction.snapshot.text(); - let new_text = prediction.buffer.update(cx, |buffer, cx| { - buffer.edit(prediction.edits.iter().cloned(), None, cx); - buffer.text() - })?; - let diff = language::unified_diff(&old_text, &new_text); - - println!("## Prediction\n"); - println!("{diff}"); - - println!("## Time\n"); - - let planning_search_time = - search_queries_generated_at.unwrap() - 
context_retrieval_started_at.unwrap(); - - println!("Planning searches: {}ms", planning_search_time.as_millis()); - println!( - "Running searches: {}ms", - (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis() - ); - - let filtering_search_time = - context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap(); - println!( - "Filtering context results: {}ms", - filtering_search_time.as_millis() - ); - - let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap(); - println!("Making Prediction: {}ms", prediction_time.as_millis()); - - println!("-------------------"); - let total_time = - (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis(); - println!("Total: {}ms", total_time); - - let inference_time = - (planning_search_time + filtering_search_time + prediction_time).as_millis(); - println!( - "Inference: {}ms ({:.2}%)", - inference_time, - (inference_time as f64 / total_time as f64) * 100. - ); - - anyhow::Ok(()) -} - async fn zeta2_syntax_context( zeta2_args: Zeta2Args, syntax_args: Zeta2SyntaxArgs, @@ -819,50 +619,62 @@ fn main() { app.run(move |cx| { let app_state = Arc::new(headless::init(cx)); cx.spawn(async move |cx| { - let result = match args.command { + match args.command { Command::Zeta1 { command: Zeta1Command::Context { context_args }, } => { let context = zeta1_context(context_args, &app_state, cx).await.unwrap(); - serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err)) + let result = serde_json::to_string_pretty(&context.body).unwrap(); + println!("{}", result); } Command::Zeta2 { command } => match command { - Zeta2Command::Predict { example_path } => { - let example = NamedExample::load(example_path).unwrap(); - zeta2_predict(example, &app_state, cx).await.unwrap(); - let _ = cx.update(|cx| cx.quit()); - return; + Zeta2Command::Predict(arguments) => { + run_zeta2_predict(arguments, &app_state, cx).await; + } + Zeta2Command::Eval(arguments) => { + run_evaluate(arguments, &app_state, cx).await; } Zeta2Command::Syntax { args, syntax_args, command, - } => match command { - Zeta2SyntaxCommand::Context { context_args } => { - zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx) + } => { + let result = match command { + Zeta2SyntaxCommand::Context { context_args } => { + zeta2_syntax_context( + args, + syntax_args, + context_args, + &app_state, + cx, + ) .await - } - Zeta2SyntaxCommand::Stats { - worktree, - extension, - limit, - skip, - } => { - retrieval_stats( + } + Zeta2SyntaxCommand::Stats { worktree, - app_state, extension, limit, skip, - syntax_args_to_options(&args, &syntax_args, false), - cx, - ) - .await - } - }, + } => { + retrieval_stats( + worktree, + app_state, + extension, + limit, + skip, + syntax_args_to_options(&args, &syntax_args, false), + cx, + ) + .await + } + }; + println!("{}", result.unwrap()); + } Zeta2Command::Llm { args, command } => match command { Zeta2LlmCommand::Context { context_args } => { - zeta2_llm_context(args, context_args, &app_state, cx).await + let result = + zeta2_llm_context(args, context_args, &app_state, cx).await; + println!("{}", result.unwrap()); } }, }, @@ -872,21 +684,10 @@ fn main() { } => { let example = NamedExample::load(path).unwrap(); example.write(output_format, io::stdout()).unwrap(); - let _ = cx.update(|cx| cx.quit()); - return; } }; - match result { - Ok(output) => { - println!("{}", output); - let _ = cx.update(|cx| cx.quit()); - } - Err(e) => { - eprintln!("Failed: 
{:?}", e); - exit(1); - } - } + let _ = cx.update(|cx| cx.quit()); }) .detach(); }); diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs new file mode 100644 index 0000000000000000000000000000000000000000..cdf385ae6db0556180ae9a223ff32efc53ad9a02 --- /dev/null +++ b/crates/zeta_cli/src/predict.rs @@ -0,0 +1,287 @@ +use crate::example::{ActualExcerpt, NamedExample}; + +use crate::headless::ZetaCliAppState; +use ::serde::Serialize; +use ::util::paths::PathStyle; +use anyhow::{Context as _, Result, anyhow}; +use clap::Args; +use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; +use futures::StreamExt as _; +use gpui::AsyncApp; +use language_model::LanguageModelRegistry; +use project::{Project, ProjectPath}; +use serde::Deserialize; +use std::cell::Cell; +use std::io::Write; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use util::rel_path::RelPath; + +#[derive(Debug, Args)] +pub struct PredictArguments { + example_path: PathBuf, + #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] + format: PredictionsOutputFormat, +} + +#[derive(clap::ValueEnum, Debug, Clone)] +pub enum PredictionsOutputFormat { + Json, + Md, + Diff, +} +pub async fn run_zeta2_predict( + args: PredictArguments, + app_state: &Arc, + cx: &mut AsyncApp, +) { + let example = NamedExample::load(args.example_path).unwrap(); + let result = zeta2_predict(example, &app_state, cx).await.unwrap(); + result.write(args.format, std::io::stdout()).unwrap(); +} + +thread_local! { + static AUTHENTICATED: Cell = const { Cell::new(false) }; +} + +pub async fn zeta2_predict( + example: NamedExample, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result { + let worktree_path = example.setup_worktree().await?; + + if !AUTHENTICATED.get() { + AUTHENTICATED.set(true); + + cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) + .unwrap() + .authenticate(cx) + }) + })? + .await?; + + app_state + .client + .sign_in_with_optional_connect(true, cx) + .await?; + } + + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&worktree_path, true, cx) + })? + .await?; + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let _edited_buffers = example.apply_edit_history(&project, cx).await?; + + let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc(); + + let cursor_buffer = project + .update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: worktree.read(cx).id(), + path: cursor_path, + }, + cx, + ) + })? 
+ .await?; + + let cursor_offset_within_excerpt = example + .example + .cursor_position + .find(CURSOR_MARKER) + .ok_or_else(|| anyhow!("missing cursor marker"))?; + let mut cursor_excerpt = example.example.cursor_position.clone(); + cursor_excerpt.replace_range( + cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), + "", + ); + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + + let mut matches = text.match_indices(&cursor_excerpt); + let Some((excerpt_offset, _)) = matches.next() else { + anyhow::bail!( + "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" + ); + }; + assert!(matches.next().is_none()); + + Ok(excerpt_offset) + })??; + + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + let cursor_anchor = + cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; + + let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; + + let refresh_task = zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&cursor_buffer, &project, cx); + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })?; + + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + let mut context_retrieval_started_at = None; + let mut context_retrieval_finished_at = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + let mut prediction_started_at = None; + let mut prediction_finished_at = None; + let mut excerpts_text = String::new(); + let mut prediction_task = None; + let mut result = PredictionDetails::default(); + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + context_retrieval_started_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { + context_retrieval_finished_at = Some(info.timestamp); + + prediction_task = Some(zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })?); + } + zeta2::ZetaDebugInfo::EditPredicted(request) => { + prediction_started_at = Some(Instant::now()); + request.response_rx.await?.map_err(|err| anyhow!(err))?; + prediction_finished_at = Some(Instant::now()); + + for included_file in request.request.included_files { + let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; + result + .excerpts + .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt { + path: included_file.path.components().skip(1).collect(), + text: String::from(excerpt.text.as_ref()), + })); + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut excerpts_text, + ); + } + break; + } + _ => {} + } + } + + refresh_task.await.context("context retrieval failed")?; + let prediction = prediction_task.unwrap().await?; + + result.diff = prediction + .map(|prediction| { + let old_text = prediction.snapshot.text(); + let new_text = prediction.buffer.update(cx, |buffer, cx| { + buffer.edit(prediction.edits.iter().cloned(), None, cx); + buffer.text() + })?; + 
anyhow::Ok(language::unified_diff(&old_text, &new_text)) + }) + .transpose()? + .unwrap_or_default(); + result.excerpts_text = excerpts_text; + + result.planning_search_time = + search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap(); + result.running_search_time = + search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap(); + result.filtering_search_time = + context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap(); + result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap(); + result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap(); + + anyhow::Ok(result) +} + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct PredictionDetails { + pub diff: String, + pub excerpts: Vec, + pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly + pub planning_search_time: Duration, + pub filtering_search_time: Duration, + pub running_search_time: Duration, + pub prediction_time: Duration, + pub total_time: Duration, +} + +impl PredictionDetails { + pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> { + let formatted = match format { + PredictionsOutputFormat::Md => self.to_markdown(), + PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?, + PredictionsOutputFormat::Diff => self.diff.clone(), + }; + + Ok(out.write_all(formatted.as_bytes())?) + } + + pub fn to_markdown(&self) -> String { + let inference_time = + self.planning_search_time + self.filtering_search_time + self.prediction_time; + + format!( + "## Excerpts\n\n\ + {}\n\n\ + ## Prediction\n\n\ + {}\n\n\ + ## Time\n\n\ + Planning searches: {}ms\n\ + Running searches: {}ms\n\ + Filtering context results: {}ms\n\ + Making Prediction: {}ms\n\n\ + -------------------\n\n\ + Total: {}ms\n\ + Inference: {}ms ({:.2}%)\n", + self.excerpts_text, + self.diff, + self.planning_search_time.as_millis(), + self.running_search_time.as_millis(), + self.filtering_search_time.as_millis(), + self.prediction_time.as_millis(), + self.total_time.as_millis(), + inference_time.as_millis(), + (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100. + ) + } +}
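
Note (illustrative, not part of the patch): a minimal standalone sketch of the line-level precision/recall/F1 scoring that `zeta-cli eval` computes for context retrieval. Inputs here are made-up `path: line` strings; the actual implementation is `precision_recall` in crates/zeta_cli/src/evaluate.rs above, which additionally aggregates true/false positive and negative counts across examples.

```rust
// Sketch only: mirrors the line-level scoring described in the PR, using
// plain std collections. Names and inputs are hypothetical.
use std::collections::HashSet;

fn precision_recall_f1(expected: &HashSet<String>, actual: &HashSet<String>) -> (f64, f64, f64) {
    let tp = expected.intersection(actual).count() as f64; // correctly retrieved lines
    let fp = actual.difference(expected).count() as f64; // retrieved but not expected
    let fn_ = expected.difference(actual).count() as f64; // expected but not retrieved
    let precision = if tp + fp == 0.0 { 0.0 } else { tp / (tp + fp) };
    let recall = if tp + fn_ == 0.0 { 0.0 } else { tp / (tp + fn_) };
    let f1 = if precision + recall == 0.0 {
        0.0
    } else {
        2.0 * precision * recall / (precision + recall)
    };
    (precision, recall, f1)
}

fn main() {
    let expected: HashSet<String> = ["src/lib.rs: fn parse()", "src/lib.rs: fn render()"]
        .into_iter()
        .map(String::from)
        .collect();
    let actual: HashSet<String> = ["src/lib.rs: fn parse()", "src/lib.rs: fn main()"]
        .into_iter()
        .map(String::from)
        .collect();
    // One shared line, one spurious retrieval, one miss:
    // precision = 0.5, recall = 0.5, F1 = 0.5
    println!("{:?}", precision_recall_f1(&expected, &actual));
}
```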