Detailed changes
@@ -1,4 +1,4 @@
-use std::borrow::Cow;
+use std::{borrow::Cow, fmt::Display};
#[derive(Debug, PartialEq)]
pub enum DiffLine<'a> {
@@ -8,7 +8,7 @@ pub enum DiffLine<'a> {
Context(&'a str),
Deletion(&'a str),
Addition(&'a str),
- Garbage,
+ Garbage(&'a str),
}
#[derive(Debug, PartialEq)]
@@ -21,7 +21,7 @@ pub struct HunkLocation {
impl<'a> DiffLine<'a> {
pub fn parse(line: &'a str) -> Self {
- Self::try_parse(line).unwrap_or(Self::Garbage)
+ Self::try_parse(line).unwrap_or(Self::Garbage(line))
}
fn try_parse(line: &'a str) -> Option<Self> {
@@ -60,6 +60,30 @@ impl<'a> DiffLine<'a> {
}
}
+impl<'a> Display for DiffLine<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ DiffLine::OldPath { path } => write!(f, "--- {path}"),
+ DiffLine::NewPath { path } => write!(f, "+++ {path}"),
+ DiffLine::HunkHeader(Some(hunk_location)) => {
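+                // Start lines are stored zero-based; `@@` headers use one-based numbers, hence the `+ 1`.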
+ write!(
+ f,
+ "@@ -{},{} +{},{} @@",
+ hunk_location.start_line_old + 1,
+ hunk_location.count_old,
+ hunk_location.start_line_new + 1,
+ hunk_location.count_new
+ )
+ }
+ DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
+ DiffLine::Context(content) => write!(f, " {content}"),
+ DiffLine::Deletion(content) => write!(f, "-{content}"),
+ DiffLine::Addition(content) => write!(f, "+{content}"),
+ DiffLine::Garbage(line) => write!(f, "{line}"),
+ }
+ }
+}
+
fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
if !header.contains(['"', '\\']) {
let path = header.split_ascii_whitespace().next().unwrap_or(header);
@@ -134,8 +158,8 @@ mod tests {
pretty_assertions::assert_eq!(
lines,
&[
- DiffLine::Garbage,
- DiffLine::Garbage,
+ DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
+ DiffLine::Garbage("index 86c770d..a1fd855 100644"),
DiffLine::OldPath {
path: "file.txt".into()
},
@@ -151,7 +175,7 @@ mod tests {
DiffLine::Context("context"),
DiffLine::Deletion("deleted"),
DiffLine::Addition("inserted"),
- DiffLine::Garbage,
+ DiffLine::Garbage("garbage"),
DiffLine::Context(""),
DiffLine::OldPath {
path: "b/file.txt".into()
@@ -0,0 +1,338 @@
+use std::{
+ fs,
+ io::IsTerminal,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+
+use anyhow::Result;
+use clap::Args;
+use cloud_llm_client::udiff::DiffLine;
+use collections::HashSet;
+use gpui::AsyncApp;
+
+use crate::{
+ example::{Example, NamedExample},
+ headless::ZetaCliAppState,
+ predict::{PredictionDetails, zeta2_predict},
+};
+
+#[derive(Debug, Args)]
+pub struct EvaluateArguments {
+ example_paths: Vec<PathBuf>,
+ #[clap(long)]
+ re_run: bool,
+}
+
+pub async fn run_evaluate(
+ args: EvaluateArguments,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) {
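+    // Evaluate every example concurrently, then aggregate the per-example scores.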
+ let example_len = args.example_paths.len();
+ let all_tasks = args.example_paths.into_iter().map(|path| {
+ let app_state = app_state.clone();
+ cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
+ });
+ let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
+
+ let aggregated_result = EvaluationResult {
+ context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
+ edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
+ };
+
+ if example_len > 1 {
+ println!("\n{}", "-".repeat(80));
+ println!("# TOTAL SCORES:");
+ println!("{}", aggregated_result.to_markdown());
+ }
+}
+
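+/// Evaluates a single example and prints the annotated context, diffs, and scores to stdout.
+/// A cached prediction is reused when present unless `re_run` is set.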
+pub async fn run_evaluate_one(
+ example_path: &Path,
+ re_run: bool,
+ app_state: Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<EvaluationResult> {
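+    // Predictions are cached on disk (under `target/zeta-prediction-cache`) so that repeated
+    // evaluation runs can skip the prediction step.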
+ let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
+ .join("../../target/zeta-prediction-cache");
+ let example = NamedExample::load(&example_path).unwrap();
+ let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
+
+ let predictions = if !re_run && example_cache_path.exists() {
+ let file_contents = fs::read_to_string(&example_cache_path)?;
+ let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
+ log::debug!(
+ "Loaded predictions from cache: {}",
+ example_cache_path.display()
+ );
+ as_json
+ } else {
+ zeta2_predict(example.clone(), &app_state, cx)
+ .await
+ .unwrap()
+ };
+
+    if re_run || !example_cache_path.exists() {
+ fs::create_dir_all(&cache_dir).unwrap();
+ fs::write(
+ example_cache_path,
+ serde_json::to_string(&predictions).unwrap(),
+ )
+ .unwrap();
+ }
+
+ let evaluation_result = evaluate(&example.example, &predictions);
+
+ println!("# {}\n", example.name);
+ println!(
+        "## Expected context:\n\n```\n{}\n```\n\n",
+ compare_context(&example.example, &predictions)
+ );
+ println!(
+ "## Expected edit prediction:\n\n```diff\n{}\n```\n",
+ compare_diffs(&example.example.expected_patch, &predictions.diff)
+ );
+ println!(
+ "## Actual edit prediction:\n\n```diff\n{}\n```\n",
+ compare_diffs(&predictions.diff, &example.example.expected_patch)
+ );
+
+ println!("{}", evaluation_result.to_markdown());
+
+ anyhow::Ok(evaluation_result)
+}
+
+#[derive(Debug, Default)]
+pub struct EvaluationResult {
+ pub context: Scores,
+ pub edit_prediction: Scores,
+}
+
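+/// Line-level precision/recall/F1, computed by treating expected and actual output as sets of lines.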
+#[derive(Default, Debug)]
+pub struct Scores {
+ pub precision: f64,
+ pub recall: f64,
+ pub f1_score: f64,
+ pub true_positives: usize,
+ pub false_positives: usize,
+ pub false_negatives: usize,
+}
+
+impl Scores {
+ pub fn to_markdown(&self) -> String {
+ format!(
+ "
+Precision : {:.4}
+Recall : {:.4}
+F1 Score : {:.4}
+True Positives : {}
+False Positives : {}
+False Negatives : {}",
+ self.precision,
+ self.recall,
+ self.f1_score,
+ self.true_positives,
+ self.false_positives,
+ self.false_negatives
+ )
+ }
+}
+
+impl Scores {
+ pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
+ let mut true_positives = 0;
+ let mut false_positives = 0;
+ let mut false_negatives = 0;
+
+ for score in scores {
+ true_positives += score.true_positives;
+ false_positives += score.false_positives;
+ false_negatives += score.false_negatives;
+ }
+
+        // `.max(1)` avoids 0/0 (NaN) when the aggregated counts are all zero.
+        let precision = true_positives as f64 / (true_positives + false_positives).max(1) as f64;
+        let recall = true_positives as f64 / (true_positives + false_negatives).max(1) as f64;
+        let mut f1_score = 2.0 * precision * recall / (precision + recall);
+        if f1_score.is_nan() {
+            f1_score = 0.0;
+        }
+
+ Scores {
+ precision,
+ recall,
+ f1_score,
+ true_positives,
+ false_positives,
+ false_negatives,
+ }
+ }
+}
+
+impl EvaluationResult {
+ pub fn to_markdown(&self) -> String {
+ format!(
+ r#"
+### Context Scores
+{}
+
+### Edit Prediction Scores
+{}
+"#,
+ self.context.to_markdown(),
+ self.edit_prediction.to_markdown()
+ )
+ }
+}
+
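+/// Scores a prediction against an example: context excerpts are compared line by line
+/// (each line prefixed with its file path), and patches are compared on their added and
+/// deleted lines only.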
+pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
+ let mut result = EvaluationResult::default();
+
+ let expected_context_lines = example
+ .expected_excerpts
+ .iter()
+ .flat_map(|excerpt| {
+ excerpt
+ .text
+ .lines()
+ .map(|line| format!("{}: {line}", excerpt.path.display()))
+ })
+ .collect();
+ let actual_context_lines = preds
+ .excerpts
+ .iter()
+ .flat_map(|excerpt| {
+ excerpt
+ .text
+ .lines()
+ .map(|line| format!("{}: {line}", excerpt.path.display()))
+ })
+ .collect();
+
+ result.context = precision_recall(&expected_context_lines, &actual_context_lines);
+
+ let expected_patch_lines = example
+ .expected_patch
+ .lines()
+ .map(DiffLine::parse)
+ .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+ .map(|line| line.to_string())
+ .collect();
+
+ let actual_patch_lines = preds
+ .diff
+ .lines()
+ .map(DiffLine::parse)
+ .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
+ .map(|line| line.to_string())
+ .collect();
+
+ result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
+
+ result
+}
+
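+/// Computes precision, recall, and F1 over the two sets of lines, returning 0.0 instead of
+/// NaN when a denominator is zero.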
+fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+ let true_positives = expected.intersection(actual).count();
+ let false_positives = actual.difference(expected).count();
+ let false_negatives = expected.difference(actual).count();
+
+ let precision = if true_positives + false_positives == 0 {
+ 0.0
+ } else {
+ true_positives as f64 / (true_positives + false_positives) as f64
+ };
+ let recall = if true_positives + false_negatives == 0 {
+ 0.0
+ } else {
+ true_positives as f64 / (true_positives + false_negatives) as f64
+ };
+ let f1_score = if precision + recall == 0.0 {
+ 0.0
+ } else {
+ 2.0 * precision * recall / (precision + recall)
+ };
+
+ Scores {
+ precision,
+ recall,
+ f1_score,
+ true_positives,
+ false_positives,
+ false_negatives,
+ }
+}
+
+/// Compares the actual retrieved context against the expected context.
+///
+/// Returns the expected context annotated with these markers:
+///
+/// `✓ context line` -- line was correctly predicted
+/// `✗ context line` -- line is missing from the predictions
+pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
+ let use_color = std::io::stdout().is_terminal();
+ let green = if use_color { "\x1b[32m" } else { "" };
+ let red = if use_color { "\x1b[31m" } else { "" };
+ let reset = if use_color { "\x1b[0m" } else { "" };
+ let expected: Vec<_> = example
+ .expected_excerpts
+ .iter()
+ .flat_map(|excerpt| {
+ excerpt
+ .text
+ .lines()
+ .map(|line| (excerpt.path.clone(), line))
+ })
+ .collect();
+ let actual: HashSet<_> = preds
+ .excerpts
+ .iter()
+ .flat_map(|excerpt| {
+ excerpt
+ .text
+ .lines()
+ .map(|line| (excerpt.path.clone(), line))
+ })
+ .collect();
+
+ let annotated = expected
+ .iter()
+ .map(|(path, line)| {
+            if actual.contains(&(path.to_path_buf(), *line)) {
+ format!("{green}✓ {line}{reset}")
+ } else {
+ format!("{red}✗ {line}{reset}")
+ }
+ })
+ .collect::<Vec<String>>();
+
+ annotated.join("\n")
+}
+
+/// Returns `patch_a` annotated line by line:
+///
+/// Additions and deletions that also appear in `patch_b` are highlighted in green, while
+/// those missing from `patch_b` are highlighted in red.
+pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
+ let use_color = std::io::stdout().is_terminal();
+ let green = if use_color { "\x1b[32m✓ " } else { "" };
+ let red = if use_color { "\x1b[31m✗ " } else { "" };
+ let neutral = if use_color { " " } else { "" };
+ let reset = if use_color { "\x1b[0m" } else { "" };
+ let lines_a = patch_a.lines().map(DiffLine::parse);
+ let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
+
+ let annotated = lines_a
+ .map(|line| match line {
+ DiffLine::Addition(_) | DiffLine::Deletion(_) => {
+ if lines_b.contains(&line) {
+ format!("{green}{line}{reset}")
+ } else {
+ format!("{red}{line}{reset}")
+ }
+ }
+ _ => format!("{neutral}{line}{reset}"),
+ })
+ .collect::<Vec<String>>();
+
+ annotated.join("\n")
+}
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
+ cell::RefCell,
env,
fmt::{self, Display},
fs,
@@ -7,12 +8,16 @@ use std::{
mem,
ops::Range,
path::{Path, PathBuf},
+ sync::Arc,
};
use anyhow::{Context as _, Result};
use clap::ValueEnum;
-use collections::HashSet;
-use futures::AsyncWriteExt as _;
+use collections::{HashMap, HashSet};
+use futures::{
+ AsyncWriteExt as _,
+ lock::{Mutex, OwnedMutexGuard},
+};
use gpui::{AsyncApp, Entity, http_client::Url};
use language::Buffer;
use project::{Project, ProjectPath};
@@ -27,13 +32,13 @@ const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct NamedExample {
pub name: String,
pub example: Example,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
pub repository_url: String,
pub revision: String,
@@ -45,10 +50,13 @@ pub struct Example {
pub expected_excerpts: Vec<ExpectedExcerpt>,
}
-#[derive(Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerpt {
- path: PathBuf,
- text: String,
+pub type ExpectedExcerpt = Excerpt;
+pub type ActualExcerpt = Excerpt;
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Excerpt {
+ pub path: PathBuf,
+ pub text: String,
}
#[derive(ValueEnum, Debug, Clone)]
@@ -171,6 +179,7 @@ impl NamedExample {
} else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
named.example.expected_patch = mem::take(&mut text);
} else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) {
+ // TODO: "…" should not be a part of the excerpt
named.example.expected_excerpts.push(ExpectedExcerpt {
path: block_info.into(),
text: mem::take(&mut text),
@@ -202,7 +211,6 @@ impl NamedExample {
}
}
- #[allow(unused)]
pub async fn setup_worktree(&self) -> Result<PathBuf> {
let (repo_owner, repo_name) = self.repo_name()?;
let file_name = self.file_name();
@@ -213,6 +221,8 @@ impl NamedExample {
fs::create_dir_all(&worktrees_dir)?;
let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let repo_lock = lock_repo(&repo_dir).await;
+
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
@@ -255,6 +265,7 @@ impl NamedExample {
)
.await?;
}
+ drop(repo_lock);
// Apply the uncommitted diff for this example.
if !self.example.uncommitted_diff.is_empty() {
@@ -419,6 +430,23 @@ impl Display for NamedExample {
}
}
+thread_local! {
+ static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
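+/// Acquires a per-repository lock so that concurrently evaluated examples sharing a checkout
+/// don't run conflicting git commands in the same directory.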
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+ REPO_LOCKS
+ .with(|cell| {
+ cell.borrow_mut()
+ .entry(path.as_ref().to_path_buf())
+ .or_default()
+ .clone()
+ })
+ .lock_owned()
+ .await
+}
+
#[must_use]
pub async fn apply_diff(
diff: &str,
@@ -487,7 +515,7 @@ pub async fn apply_diff(
});
}
}
- DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
+ DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {}
}
let at_hunk_end = match diff_lines.peek() {
@@ -1,14 +1,17 @@
+mod evaluate;
mod example;
mod headless;
+mod predict;
mod source_location;
mod syntax_retrieval_stats;
mod util;
+use crate::evaluate::{EvaluateArguments, run_evaluate};
use crate::example::{ExampleFormat, NamedExample};
+use crate::predict::{PredictArguments, run_zeta2_predict};
use crate::syntax_retrieval_stats::retrieval_stats;
use ::serde::Serialize;
use ::util::paths::PathStyle;
-use ::util::rel_path::RelPath;
use anyhow::{Context as _, Result, anyhow};
use clap::{Args, Parser, Subcommand};
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -22,12 +25,12 @@ use futures::channel::mpsc;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
use language_model::LanguageModelRegistry;
-use project::{Project, ProjectPath, Worktree};
+use project::{Project, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
-use std::io;
-use std::time::{Duration, Instant};
-use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
+use std::io::{self};
+use std::time::Duration;
+use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
use crate::headless::ZetaCliAppState;
@@ -82,9 +85,8 @@ enum Zeta2Command {
#[command(subcommand)]
command: Zeta2LlmCommand,
},
- Predict {
- example_path: PathBuf,
- },
+ Predict(PredictArguments),
+ Eval(EvaluateArguments),
}
#[derive(Subcommand, Debug)]
@@ -327,208 +329,6 @@ async fn load_context(
})
}
-async fn zeta2_predict(
- example: NamedExample,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let worktree_path = example.setup_worktree().await?;
-
- cx.update(|cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry
- .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
- .unwrap()
- .authenticate(cx)
- })
- })?
- .await?;
-
- app_state
- .client
- .sign_in_with_optional_connect(true, cx)
- .await?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- let _edited_buffers = example.apply_edit_history(&project, cx).await?;
-
- let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
-
- let cursor_buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: worktree.read(cx).id(),
- path: cursor_path,
- },
- cx,
- )
- })?
- .await?;
-
- let cursor_offset_within_excerpt = example
- .example
- .cursor_position
- .find(CURSOR_MARKER)
- .ok_or_else(|| anyhow!("missing cursor marker"))?;
- let mut cursor_excerpt = example.example.cursor_position.clone();
- cursor_excerpt.replace_range(
- cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
- "",
- );
- let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
- let text = buffer.text();
-
- let mut matches = text.match_indices(&cursor_excerpt);
- let Some((excerpt_offset, _)) = matches.next() else {
- anyhow::bail!(
- "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
- );
- };
- assert!(matches.next().is_none());
-
- Ok(excerpt_offset)
- })??;
-
- let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
- let cursor_anchor =
- cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
-
- let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
-
- let refresh_task = zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(&cursor_buffer, &project, cx);
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?;
-
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
- let mut context_retrieval_started_at = None;
- let mut context_retrieval_finished_at = None;
- let mut search_queries_generated_at = None;
- let mut search_queries_executed_at = None;
- let mut prediction_started_at = None;
- let mut prediction_finished_at = None;
- let mut excerpts_text = String::new();
- let mut prediction_task = None;
- while let Some(event) = debug_rx.next().await {
- match event {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
- context_retrieval_started_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- search_queries_generated_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
- search_queries_executed_at = Some(info.timestamp);
- }
- zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
- context_retrieval_finished_at = Some(info.timestamp);
-
- prediction_task = Some(zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
- })?);
- }
- zeta2::ZetaDebugInfo::EditPredicted(request) => {
- prediction_started_at = Some(Instant::now());
- request.response_rx.await?.map_err(|err| anyhow!(err))?;
- prediction_finished_at = Some(Instant::now());
-
- for included_file in request.request.included_files {
- let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut excerpts_text,
- );
- }
- break;
- }
- _ => {}
- }
- }
-
- refresh_task.await.context("context retrieval failed")?;
- let prediction = prediction_task.unwrap().await?.context("No prediction")?;
-
- println!("## Excerpts\n");
- println!("{excerpts_text}");
-
- let old_text = prediction.snapshot.text();
- let new_text = prediction.buffer.update(cx, |buffer, cx| {
- buffer.edit(prediction.edits.iter().cloned(), None, cx);
- buffer.text()
- })?;
- let diff = language::unified_diff(&old_text, &new_text);
-
- println!("## Prediction\n");
- println!("{diff}");
-
- println!("## Time\n");
-
- let planning_search_time =
- search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
-
- println!("Planning searches: {}ms", planning_search_time.as_millis());
- println!(
- "Running searches: {}ms",
- (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis()
- );
-
- let filtering_search_time =
- context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
- println!(
- "Filtering context results: {}ms",
- filtering_search_time.as_millis()
- );
-
- let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
- println!("Making Prediction: {}ms", prediction_time.as_millis());
-
- println!("-------------------");
- let total_time =
- (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis();
- println!("Total: {}ms", total_time);
-
- let inference_time =
- (planning_search_time + filtering_search_time + prediction_time).as_millis();
- println!(
- "Inference: {}ms ({:.2}%)",
- inference_time,
- (inference_time as f64 / total_time as f64) * 100.
- );
-
- anyhow::Ok(())
-}
-
async fn zeta2_syntax_context(
zeta2_args: Zeta2Args,
syntax_args: Zeta2SyntaxArgs,
@@ -819,50 +619,62 @@ fn main() {
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
cx.spawn(async move |cx| {
- let result = match args.command {
+ match args.command {
Command::Zeta1 {
command: Zeta1Command::Context { context_args },
} => {
let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
- serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
+ let result = serde_json::to_string_pretty(&context.body).unwrap();
+ println!("{}", result);
}
Command::Zeta2 { command } => match command {
- Zeta2Command::Predict { example_path } => {
- let example = NamedExample::load(example_path).unwrap();
- zeta2_predict(example, &app_state, cx).await.unwrap();
- let _ = cx.update(|cx| cx.quit());
- return;
+ Zeta2Command::Predict(arguments) => {
+ run_zeta2_predict(arguments, &app_state, cx).await;
+ }
+ Zeta2Command::Eval(arguments) => {
+ run_evaluate(arguments, &app_state, cx).await;
}
Zeta2Command::Syntax {
args,
syntax_args,
command,
- } => match command {
- Zeta2SyntaxCommand::Context { context_args } => {
- zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx)
+ } => {
+ let result = match command {
+ Zeta2SyntaxCommand::Context { context_args } => {
+ zeta2_syntax_context(
+ args,
+ syntax_args,
+ context_args,
+ &app_state,
+ cx,
+ )
.await
- }
- Zeta2SyntaxCommand::Stats {
- worktree,
- extension,
- limit,
- skip,
- } => {
- retrieval_stats(
+ }
+ Zeta2SyntaxCommand::Stats {
worktree,
- app_state,
extension,
limit,
skip,
- syntax_args_to_options(&args, &syntax_args, false),
- cx,
- )
- .await
- }
- },
+ } => {
+ retrieval_stats(
+ worktree,
+ app_state,
+ extension,
+ limit,
+ skip,
+ syntax_args_to_options(&args, &syntax_args, false),
+ cx,
+ )
+ .await
+ }
+ };
+ println!("{}", result.unwrap());
+ }
Zeta2Command::Llm { args, command } => match command {
Zeta2LlmCommand::Context { context_args } => {
- zeta2_llm_context(args, context_args, &app_state, cx).await
+ let result =
+ zeta2_llm_context(args, context_args, &app_state, cx).await;
+ println!("{}", result.unwrap());
}
},
},
@@ -872,21 +684,10 @@ fn main() {
} => {
let example = NamedExample::load(path).unwrap();
example.write(output_format, io::stdout()).unwrap();
- let _ = cx.update(|cx| cx.quit());
- return;
}
};
- match result {
- Ok(output) => {
- println!("{}", output);
- let _ = cx.update(|cx| cx.quit());
- }
- Err(e) => {
- eprintln!("Failed: {:?}", e);
- exit(1);
- }
- }
+ let _ = cx.update(|cx| cx.quit());
})
.detach();
});
@@ -0,0 +1,287 @@
+use crate::example::{ActualExcerpt, NamedExample};
+
+use crate::headless::ZetaCliAppState;
+use ::serde::Serialize;
+use ::util::paths::PathStyle;
+use anyhow::{Context as _, Result, anyhow};
+use clap::Args;
+use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use futures::StreamExt as _;
+use gpui::AsyncApp;
+use language_model::LanguageModelRegistry;
+use project::{Project, ProjectPath};
+use serde::Deserialize;
+use std::cell::Cell;
+use std::io::Write;
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use util::rel_path::RelPath;
+
+#[derive(Debug, Args)]
+pub struct PredictArguments {
+ example_path: PathBuf,
+ #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
+ format: PredictionsOutputFormat,
+}
+
+#[derive(clap::ValueEnum, Debug, Clone)]
+pub enum PredictionsOutputFormat {
+ Json,
+ Md,
+ Diff,
+}
+pub async fn run_zeta2_predict(
+ args: PredictArguments,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) {
+ let example = NamedExample::load(args.example_path).unwrap();
+ let result = zeta2_predict(example, &app_state, cx).await.unwrap();
+ result.write(args.format, std::io::stdout()).unwrap();
+}
+
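+// Authenticate once per thread instead of once per predicted example.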
+thread_local! {
+ static AUTHENTICATED: Cell<bool> = const { Cell::new(false) };
+}
+
+pub async fn zeta2_predict(
+ example: NamedExample,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<PredictionDetails> {
+ let worktree_path = example.setup_worktree().await?;
+
+ if !AUTHENTICATED.get() {
+ AUTHENTICATED.set(true);
+
+ cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+ .unwrap()
+ .authenticate(cx)
+ })
+ })?
+ .await?;
+
+ app_state
+ .client
+ .sign_in_with_optional_connect(true, cx)
+ .await?;
+ }
+
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let _edited_buffers = example.apply_edit_history(&project, cx).await?;
+
+ let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })?
+ .await?;
+
+ let cursor_offset_within_excerpt = example
+ .example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))?;
+ let mut cursor_excerpt = example.example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let Some((excerpt_offset, _)) = matches.next() else {
+ anyhow::bail!(
+ "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+ );
+ };
+ assert!(matches.next().is_none());
+
+ Ok(excerpt_offset)
+ })??;
+
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_anchor =
+ cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+
+ let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+ let refresh_task = zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(&cursor_buffer, &project, cx);
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?;
+
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+ let mut context_retrieval_started_at = None;
+ let mut context_retrieval_finished_at = None;
+ let mut search_queries_generated_at = None;
+ let mut search_queries_executed_at = None;
+ let mut prediction_started_at = None;
+ let mut prediction_finished_at = None;
+ let mut excerpts_text = String::new();
+ let mut prediction_task = None;
+ let mut result = PredictionDetails::default();
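+    // Follow Zeta's debug events to time each phase, kick off the prediction once context
+    // retrieval finishes, and record which excerpts were included in the request.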
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ context_retrieval_started_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
+ search_queries_generated_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
+ search_queries_executed_at = Some(info.timestamp);
+ }
+ zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ context_retrieval_finished_at = Some(info.timestamp);
+
+ prediction_task = Some(zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
+ })?);
+ }
+ zeta2::ZetaDebugInfo::EditPredicted(request) => {
+ prediction_started_at = Some(Instant::now());
+ request.response_rx.await?.map_err(|err| anyhow!(err))?;
+ prediction_finished_at = Some(Instant::now());
+
+ for included_file in request.request.included_files {
+ let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+ result
+ .excerpts
+ .extend(included_file.excerpts.iter().map(|excerpt| ActualExcerpt {
+ path: included_file.path.components().skip(1).collect(),
+ text: String::from(excerpt.text.as_ref()),
+ }));
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut excerpts_text,
+ );
+ }
+ break;
+ }
+ _ => {}
+ }
+ }
+
+ refresh_task.await.context("context retrieval failed")?;
+ let prediction = prediction_task.unwrap().await?;
+
+ result.diff = prediction
+ .map(|prediction| {
+ let old_text = prediction.snapshot.text();
+ let new_text = prediction.buffer.update(cx, |buffer, cx| {
+ buffer.edit(prediction.edits.iter().cloned(), None, cx);
+ buffer.text()
+ })?;
+ anyhow::Ok(language::unified_diff(&old_text, &new_text))
+ })
+ .transpose()?
+ .unwrap_or_default();
+ result.excerpts_text = excerpts_text;
+
+ result.planning_search_time =
+ search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap();
+ result.running_search_time =
+ search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap();
+ result.filtering_search_time =
+ context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap();
+ result.prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap();
+ result.total_time = prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap();
+
+ anyhow::Ok(result)
+}
+
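+/// Everything recorded for a single prediction run: the retrieved excerpts, the predicted
+/// diff, and per-phase timings. Serialized to JSON so evaluation runs can cache predictions.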
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct PredictionDetails {
+ pub diff: String,
+ pub excerpts: Vec<ActualExcerpt>,
+ pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
+ pub planning_search_time: Duration,
+ pub filtering_search_time: Duration,
+ pub running_search_time: Duration,
+ pub prediction_time: Duration,
+ pub total_time: Duration,
+}
+
+impl PredictionDetails {
+ pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
+ let formatted = match format {
+ PredictionsOutputFormat::Md => self.to_markdown(),
+ PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
+ PredictionsOutputFormat::Diff => self.diff.clone(),
+ };
+
+ Ok(out.write_all(formatted.as_bytes())?)
+ }
+
+ pub fn to_markdown(&self) -> String {
+ let inference_time =
+ self.planning_search_time + self.filtering_search_time + self.prediction_time;
+
+ format!(
+ "## Excerpts\n\n\
+ {}\n\n\
+ ## Prediction\n\n\
+ {}\n\n\
+ ## Time\n\n\
+ Planning searches: {}ms\n\
+ Running searches: {}ms\n\
+ Filtering context results: {}ms\n\
+ Making Prediction: {}ms\n\n\
+ -------------------\n\n\
+ Total: {}ms\n\
+ Inference: {}ms ({:.2}%)\n",
+ self.excerpts_text,
+ self.diff,
+ self.planning_search_time.as_millis(),
+ self.running_search_time.as_millis(),
+ self.filtering_search_time.as_millis(),
+ self.prediction_time.as_millis(),
+ self.total_time.as_millis(),
+ inference_time.as_millis(),
+ (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
+ )
+ }
+}