From 784fdcaee3be7bbcf9511097cf620256dc2f7ef6 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 6 Nov 2025 15:36:58 -0800 Subject: [PATCH] zeta2: Build edit prediction prompt and process model output in client (#41870) Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Ben Kunkle Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> --- Cargo.lock | 6 +- crates/acp_thread/Cargo.toml | 1 + .../cloud_llm_client/src/cloud_llm_client.rs | 1 - .../cloud_llm_client/src/predict_edits_v3.rs | 37 +- crates/cloud_llm_client/src/udiff.rs | 294 ----- crates/cloud_zeta2_prompt/Cargo.toml | 2 + .../src/cloud_zeta2_prompt.rs | 7 +- .../src/retrieval_prompt.rs | 92 ++ crates/codestral/src/codestral.rs | 8 +- crates/edit_prediction/src/edit_prediction.rs | 10 +- crates/editor/src/edit_prediction_tests.rs | 8 +- crates/editor/src/editor.rs | 19 +- crates/editor/src/editor_tests.rs | 2 +- crates/language/src/buffer.rs | 11 +- crates/language/src/buffer_tests.rs | 8 +- crates/open_ai/src/open_ai.rs | 27 +- .../src/supermaven_completion_provider.rs | 10 +- crates/zeta/src/zeta.rs | 49 +- crates/zeta2/Cargo.toml | 3 +- crates/zeta2/src/prediction.rs | 260 +---- crates/zeta2/src/related_excerpts.rs | 717 ------------ crates/zeta2/src/retrieval_search.rs | 194 ++++ crates/zeta2/src/udiff.rs | 1024 +++++++++++++++++ crates/zeta2/src/zeta2.rs | 826 ++++++++----- crates/zeta2_tools/Cargo.toml | 3 +- crates/zeta2_tools/src/zeta2_context_view.rs | 50 +- crates/zeta2_tools/src/zeta2_tools.rs | 93 +- crates/zeta_cli/src/evaluate.rs | 9 +- crates/zeta_cli/src/example.rs | 479 ++------ crates/zeta_cli/src/main.rs | 229 +--- crates/zeta_cli/src/paths.rs | 8 + crates/zeta_cli/src/predict.rs | 101 +- 32 files changed, 2197 insertions(+), 2391 deletions(-) delete mode 100644 crates/cloud_llm_client/src/udiff.rs create mode 100644 crates/cloud_zeta2_prompt/src/retrieval_prompt.rs delete mode 100644 crates/zeta2/src/related_excerpts.rs create mode 
100644 crates/zeta2/src/retrieval_search.rs create mode 100644 crates/zeta2/src/udiff.rs create mode 100644 crates/zeta_cli/src/paths.rs diff --git a/Cargo.lock b/Cargo.lock index ddc18ba3c0e5ce089d12139a28a737c05ca8de03..a3a8fed78dcaae90eae9b026d2968eaeb2f8aad2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,7 @@ dependencies = [ "util", "uuid", "watch", + "zlog", ] [[package]] @@ -3198,7 +3199,9 @@ dependencies = [ "indoc", "ordered-float 2.10.1", "rustc-hash 2.1.1", + "schemars 1.0.4", "serde", + "serde_json", "strum 0.27.2", ] @@ -21675,10 +21678,10 @@ dependencies = [ "language_model", "log", "lsp", + "open_ai", "pretty_assertions", "project", "release_channel", - "schemars 1.0.4", "serde", "serde_json", "settings", @@ -21687,6 +21690,7 @@ dependencies = [ "uuid", "workspace", "worktree", + "zlog", ] [[package]] diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 09202dc57cb96f5f258e64063f5d61169fa7a045..4030dd89c5497c3fdd4af06d725aed8755da3cf5 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -56,3 +56,4 @@ rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true +zlog.workspace = true diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index afa72665f168e7ec341d92df0a094f7880368087..bb77c3a5b7f8009093cbf7bc427160ed535e6c62 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,5 +1,4 @@ pub mod predict_edits_v3; -pub mod udiff; use std::str::FromStr; use std::sync::Arc; diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 7166139d9077394e684a8b53ce3d8300cb5fa2db..2e884ae9fcb27530e5579b83767bde95b5df414c 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,7 +1,7 @@ use chrono::Duration; use 
serde::{Deserialize, Serialize}; use std::{ - fmt::Display, + fmt::{Display, Write as _}, ops::{Add, Range, Sub}, path::{Path, PathBuf}, sync::Arc, @@ -11,7 +11,14 @@ use uuid::Uuid; use crate::PredictEditsGitInfo; -// TODO: snippet ordering within file / relative to excerpt +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanContextRetrievalRequest { + pub excerpt: String, + pub excerpt_path: Arc, + pub excerpt_line_range: Range, + pub cursor_file_max_row: Line, + pub events: Vec, +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsRequest { @@ -125,15 +132,15 @@ impl Display for Event { write!( f, "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}", - old_path.display(), - new_path.display() + DiffPathFmt(old_path), + DiffPathFmt(new_path) ) } else { write!( f, "--- a/{}\n+++ b/{}\n{diff}", - old_path.display(), - new_path.display() + DiffPathFmt(old_path), + DiffPathFmt(new_path) ) } } @@ -141,6 +148,24 @@ impl Display for Event { } } +/// always format the Path as a unix path with `/` as the path sep in Diffs +pub struct DiffPathFmt<'a>(pub &'a Path); + +impl<'a> std::fmt::Display for DiffPathFmt<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut is_first = true; + for component in self.0.components() { + if !is_first { + f.write_char('/')?; + } else { + is_first = false; + } + write!(f, "{}", component.as_os_str().display())?; + } + Ok(()) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Signature { pub text: String, diff --git a/crates/cloud_llm_client/src/udiff.rs b/crates/cloud_llm_client/src/udiff.rs deleted file mode 100644 index c5972fc139dd105c9d4a0f4d5917950752278cb7..0000000000000000000000000000000000000000 --- a/crates/cloud_llm_client/src/udiff.rs +++ /dev/null @@ -1,294 +0,0 @@ -use std::{borrow::Cow, fmt::Display}; - -#[derive(Debug, PartialEq)] -pub enum DiffLine<'a> { - OldPath { path: Cow<'a, str> }, - NewPath { path: Cow<'a, str> }, - 
HunkHeader(Option), - Context(&'a str), - Deletion(&'a str), - Addition(&'a str), - Garbage(&'a str), -} - -#[derive(Debug, PartialEq)] -pub struct HunkLocation { - start_line_old: u32, - count_old: u32, - start_line_new: u32, - count_new: u32, -} - -impl<'a> DiffLine<'a> { - pub fn parse(line: &'a str) -> Self { - Self::try_parse(line).unwrap_or(Self::Garbage(line)) - } - - fn try_parse(line: &'a str) -> Option { - if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) { - let path = parse_header_path("a/", header); - Some(Self::OldPath { path }) - } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) { - Some(Self::NewPath { - path: parse_header_path("b/", header), - }) - } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) { - if header.starts_with("...") { - return Some(Self::HunkHeader(None)); - } - - let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; - let mut parts = header.split_ascii_whitespace(); - let count_old = parts.next()?; - let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?; - - Some(Self::HunkHeader(Some(HunkLocation { - start_line_old: start_line_old.parse::().ok()?.saturating_sub(1), - count_old: count_old.parse().ok()?, - start_line_new: start_line_new.parse::().ok()?.saturating_sub(1), - count_new: count_new.parse().ok()?, - }))) - } else if let Some(deleted_header) = line.strip_prefix("-") { - Some(Self::Deletion(deleted_header)) - } else if line.is_empty() { - Some(Self::Context("")) - } else if let Some(context) = line.strip_prefix(" ") { - Some(Self::Context(context)) - } else { - Some(Self::Addition(line.strip_prefix("+")?)) - } - } -} - -impl<'a> Display for DiffLine<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DiffLine::OldPath { path } => write!(f, "--- {path}"), - DiffLine::NewPath { path } => write!(f, "+++ {path}"), - 
DiffLine::HunkHeader(Some(hunk_location)) => { - write!( - f, - "@@ -{},{} +{},{} @@", - hunk_location.start_line_old + 1, - hunk_location.count_old, - hunk_location.start_line_new + 1, - hunk_location.count_new - ) - } - DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"), - DiffLine::Context(content) => write!(f, " {content}"), - DiffLine::Deletion(content) => write!(f, "-{content}"), - DiffLine::Addition(content) => write!(f, "+{content}"), - DiffLine::Garbage(line) => write!(f, "{line}"), - } - } -} - -fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { - if !header.contains(['"', '\\']) { - let path = header.split_ascii_whitespace().next().unwrap_or(header); - return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path)); - } - - let mut path = String::with_capacity(header.len()); - let mut in_quote = false; - let mut chars = header.chars().peekable(); - let mut strip_prefix = Some(strip_prefix); - - while let Some(char) = chars.next() { - if char == '"' { - in_quote = !in_quote; - } else if char == '\\' { - let Some(&next_char) = chars.peek() else { - break; - }; - chars.next(); - path.push(next_char); - } else if char.is_ascii_whitespace() && !in_quote { - break; - } else { - path.push(char); - } - - if let Some(prefix) = strip_prefix - && path == prefix - { - strip_prefix.take(); - path.clear(); - } - } - - Cow::Owned(path) -} - -fn eat_required_whitespace(header: &str) -> Option<&str> { - let trimmed = header.trim_ascii_start(); - - if trimmed.len() == header.len() { - None - } else { - Some(trimmed) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use indoc::indoc; - - #[test] - fn parse_lines_simple() { - let input = indoc! 
{" - diff --git a/text.txt b/text.txt - index 86c770d..a1fd855 100644 - --- a/file.txt - +++ b/file.txt - @@ -1,2 +1,3 @@ - context - -deleted - +inserted - garbage - - --- b/file.txt - +++ a/file.txt - "}; - - let lines = input.lines().map(DiffLine::parse).collect::>(); - - pretty_assertions::assert_eq!( - lines, - &[ - DiffLine::Garbage("diff --git a/text.txt b/text.txt"), - DiffLine::Garbage("index 86c770d..a1fd855 100644"), - DiffLine::OldPath { - path: "file.txt".into() - }, - DiffLine::NewPath { - path: "file.txt".into() - }, - DiffLine::HunkHeader(Some(HunkLocation { - start_line_old: 0, - count_old: 2, - start_line_new: 0, - count_new: 3 - })), - DiffLine::Context("context"), - DiffLine::Deletion("deleted"), - DiffLine::Addition("inserted"), - DiffLine::Garbage("garbage"), - DiffLine::Context(""), - DiffLine::OldPath { - path: "b/file.txt".into() - }, - DiffLine::NewPath { - path: "a/file.txt".into() - }, - ] - ); - } - - #[test] - fn file_header_extra_space() { - let options = ["--- file", "--- file", "---\tfile"]; - - for option in options { - pretty_assertions::assert_eq!( - DiffLine::parse(option), - DiffLine::OldPath { - path: "file".into() - }, - "{option}", - ); - } - } - - #[test] - fn hunk_header_extra_space() { - let options = [ - "@@ -1,2 +1,3 @@", - "@@ -1,2 +1,3 @@", - "@@\t-1,2\t+1,3\t@@", - "@@ -1,2 +1,3 @@", - "@@ -1,2 +1,3 @@", - "@@ -1,2 +1,3 @@", - "@@ -1,2 +1,3 @@ garbage", - ]; - - for option in options { - pretty_assertions::assert_eq!( - DiffLine::parse(option), - DiffLine::HunkHeader(Some(HunkLocation { - start_line_old: 0, - count_old: 2, - start_line_new: 0, - count_new: 3 - })), - "{option}", - ); - } - } - - #[test] - fn hunk_header_without_location() { - pretty_assertions::assert_eq!(DiffLine::parse("@@ ... 
@@"), DiffLine::HunkHeader(None)); - } - - #[test] - fn test_parse_path() { - assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt"); - assert_eq!( - parse_header_path("a/", "foo/bar/baz.txt"), - "foo/bar/baz.txt" - ); - assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt"); - assert_eq!( - parse_header_path("a/", "a/foo/bar/baz.txt"), - "foo/bar/baz.txt" - ); - - // Extra - assert_eq!( - parse_header_path("a/", "a/foo/bar/baz.txt 2025"), - "foo/bar/baz.txt" - ); - assert_eq!( - parse_header_path("a/", "a/foo/bar/baz.txt\t2025"), - "foo/bar/baz.txt" - ); - assert_eq!( - parse_header_path("a/", "a/foo/bar/baz.txt \""), - "foo/bar/baz.txt" - ); - - // Quoted - assert_eq!( - parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""), - "foo/bar/baz quox.txt" - ); - assert_eq!( - parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""), - "foo/bar/baz quox.txt" - ); - assert_eq!( - parse_header_path("a/", "\"foo/bar/baz quox.txt\""), - "foo/bar/baz quox.txt" - ); - assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷"); - assert_eq!( - parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"), - "foo/bar/baz quox.txt" - ); - // unescaped quotes are dropped - assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar"); - - // Escaped - assert_eq!( - parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""), - "foo/\"bar\"/baz.txt" - ); - assert_eq!( - parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""), - "C:\\Projects\\My App\\old file.txt" - ); - } -} diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml index 43446f460c872afcdfe1d4bc47d14f894f0c9c09..fa8246950f8d03029388e0276954de946efc2346 100644 --- a/crates/cloud_zeta2_prompt/Cargo.toml +++ b/crates/cloud_zeta2_prompt/Cargo.toml @@ -17,5 +17,7 @@ cloud_llm_client.workspace = true indoc.workspace = true ordered-float.workspace = true rustc-hash.workspace = true +schemars.workspace = true serde.workspace = true +serde_json.workspace = true 
strum.workspace = true diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 6caf9941845146dc0c30c4606f677e5ec816c137..7fb79906f29f38579feef82bb25e7ed42d1d6c83 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,8 +1,9 @@ //! Zeta2 prompt planning and generation code shared with cloud. +pub mod retrieval_prompt; use anyhow::{Context as _, Result, anyhow}; use cloud_llm_client::predict_edits_v3::{ - self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration, + self, DiffPathFmt, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration, }; use indoc::indoc; use ordered_float::OrderedFloat; @@ -212,7 +213,7 @@ pub fn write_codeblock<'a>( include_line_numbers: bool, output: &'a mut String, ) { - writeln!(output, "`````{}", path.display()).unwrap(); + writeln!(output, "`````{}", DiffPathFmt(path)).unwrap(); write_excerpts( excerpts, sorted_insertions, @@ -275,7 +276,7 @@ pub fn write_excerpts<'a>( } } -fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { +pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { if events.is_empty() { return; }; diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs new file mode 100644 index 0000000000000000000000000000000000000000..54ef1999729f6976bd77d280508f8c370d54488e --- /dev/null +++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs @@ -0,0 +1,92 @@ +use anyhow::Result; +use cloud_llm_client::predict_edits_v3::{self, Excerpt}; +use indoc::indoc; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{fmt::Write, sync::LazyLock}; + +use crate::{push_events, write_codeblock}; + +pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result { + let mut prompt = SEARCH_INSTRUCTIONS.to_string(); + + if !request.events.is_empty() { + 
writeln!(&mut prompt, "## User Edits\n")?; + push_events(&mut prompt, &request.events); + } + + writeln!(&mut prompt, "## Excerpt around the cursor\n")?; + write_codeblock( + &request.excerpt_path, + &[Excerpt { + start_line: request.excerpt_line_range.start, + text: request.excerpt.into(), + }], + &[], + request.cursor_file_max_row, + true, + &mut prompt, + ); + + writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?; + + Ok(prompt) +} + +/// Search for relevant code +/// +/// For the best results, run multiple queries at once with a single invocation of this tool. +#[derive(Clone, Deserialize, Serialize, JsonSchema)] +pub struct SearchToolInput { + /// An array of queries to run for gathering context relevant to the next prediction + #[schemars(length(max = 5))] + pub queries: Box<[SearchToolQuery]>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SearchToolQuery { + /// A glob pattern to match file paths in the codebase + pub glob: String, + /// A regular expression to match content within the files matched by the glob pattern + pub regex: String, +} + +pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { + let schema = schemars::schema_for!(SearchToolInput); + + let description = schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap() + .to_string(); + + (schema.into(), description) +}); + +pub const TOOL_NAME: &str = "search"; + +const SEARCH_INSTRUCTIONS: &str = indoc! {r#" + ## Task + + You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations + that will serve as context for predicting the next required edit. 
+ + **Your task:** + - Analyze the user's recent edits and current cursor context + - Use the `search` tool to find code that may be relevant for predicting the next edit + - Focus on finding: + - Code patterns that might need similar changes based on the recent edits + - Functions, variables, types, and constants referenced in the current cursor context + - Related implementations, usages, or dependencies that may require consistent updates + + **Important constraints:** + - This conversation has exactly 2 turns + - You must make ALL search queries in your first response via the `search` tool + - All queries will be executed in parallel and results returned together + - In the second turn, you will select the most relevant results via the `select` tool. +"#}; + +const TOOL_USE_REMINDER: &str = indoc! {" + -- + Use the `search` tool now +"}; diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index e439cfb974fb55f4d30e5eb4be5c0dfa0d77c3d3..9fbd207a809fb2cb3ac685ea6629a36c8631d1fe 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -34,7 +34,7 @@ struct CurrentCompletion { snapshot: BufferSnapshot, /// The edits that should be applied to transform the original text into the predicted text. /// Each edit is a range in the buffer and the text to replace it with. - edits: Arc<[(Range, String)]>, + edits: Arc<[(Range, Arc)]>, /// Preview of how the buffer will look after applying the edits. edit_preview: EditPreview, } @@ -42,7 +42,7 @@ struct CurrentCompletion { impl CurrentCompletion { /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated. /// Returns None if the user's edits conflict with the predicted edits. 
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { + fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) } } @@ -281,8 +281,8 @@ impl EditPredictionProvider for CodestralCompletionProvider { return Ok(()); } - let edits: Arc<[(Range, String)]> = - vec![(cursor_position..cursor_position, completion_text)].into(); + let edits: Arc<[(Range, Arc)]> = + vec![(cursor_position..cursor_position, completion_text.into())].into(); let edit_preview = buffer .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))? .await; diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 22cb1047d1dda93b639990e549f9b76b3ff385f5..c9bb0672a0c9cb7c56c3c703b0e10594d56cc0c1 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,4 +1,4 @@ -use std::ops::Range; +use std::{ops::Range, sync::Arc}; use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; @@ -19,7 +19,7 @@ pub enum EditPrediction { /// Edits within the buffer that requested the prediction Local { id: Option, - edits: Vec<(Range, String)>, + edits: Vec<(Range, Arc)>, edit_preview: Option, }, /// Jump to a different file from the one that requested the prediction @@ -248,8 +248,8 @@ where pub fn interpolate_edits( old_snapshot: &BufferSnapshot, new_snapshot: &BufferSnapshot, - current_edits: &[(Range, String)], -) -> Option, String)>> { + current_edits: &[(Range, Arc)], +) -> Option, Arc)>> { let mut edits = Vec::new(); let mut model_edits = current_edits.iter().peekable(); @@ -274,7 +274,7 @@ pub fn interpolate_edits( if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { if !model_suffix.is_empty() { let anchor = old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.to_string())); + 
edits.push((anchor..anchor, model_suffix.into())); } model_edits.next(); diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index d897a670674cb23b075646c64f22e8b9bf0e4f90..74f13a404c6a52db448d68eba9e5c255e7276923 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -2,7 +2,7 @@ use edit_prediction::EditPredictionProvider; use gpui::{Entity, KeyBinding, Modifiers, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; -use std::ops::Range; +use std::{ops::Range, sync::Arc}; use text::{Point, ToOffset}; use crate::{ @@ -24,7 +24,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); - assert_eq!(edits[0].1.as_str(), "-273.15"); + assert_eq!(edits[0].1.as_ref(), "-273.15"); }); accept_completion(&mut cx); @@ -46,7 +46,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { assert_editor_active_edit_completion(&mut cx, |_, edits| { assert_eq!(edits.len(), 1); - assert_eq!(edits[0].1.as_str(), "3.14159"); + assert_eq!(edits[0].1.as_ref(), "3.14159"); }); accept_completion(&mut cx); @@ -330,7 +330,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA fn assert_editor_active_edit_completion( cx: &mut EditorTestContext, - assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range, String)>), + assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range, Arc)>), ) { cx.editor(|editor, _, cx| { let completion_state = editor diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 3bbd366795c4eab5a3febee2520213788799c451..4cb566f93e2bf74c5982f370e8ccaca548d4fb11 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -616,7 +616,7 @@ pub(crate) enum EditDisplayMode { enum EditPrediction { Edit { - edits: Vec<(Range, String)>, + edits: Vec<(Range, 
Arc)>, edit_preview: Option, display_mode: EditDisplayMode, snapshot: BufferSnapshot, @@ -7960,7 +7960,7 @@ impl Editor { let inlay = Inlay::edit_prediction( post_inc(&mut self.next_inlay_id), range.start, - new_text.as_str(), + new_text.as_ref(), ); inlay_ids.push(inlay.id); inlays.push(inlay); @@ -8982,7 +8982,7 @@ impl Editor { newest_selection_head: Option, editor_width: Pixels, style: &EditorStyle, - edits: &Vec<(Range, String)>, + edits: &Vec<(Range, Arc)>, edit_preview: &Option, snapshot: &language::BufferSnapshot, window: &mut Window, @@ -24382,25 +24382,20 @@ impl InvalidationRegion for SnippetState { fn edit_prediction_edit_text( current_snapshot: &BufferSnapshot, - edits: &[(Range, String)], + edits: &[(Range, impl AsRef)], edit_preview: &EditPreview, include_deletions: bool, cx: &App, ) -> HighlightedText { let edits = edits .iter() - .map(|(anchor, text)| { - ( - anchor.start.text_anchor..anchor.end.text_anchor, - text.clone(), - ) - }) + .map(|(anchor, text)| (anchor.start.text_anchor..anchor.end.text_anchor, text)) .collect::>(); edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx) } -fn edit_prediction_fallback_text(edits: &[(Range, String)], cx: &App) -> HighlightedText { +fn edit_prediction_fallback_text(edits: &[(Range, Arc)], cx: &App) -> HighlightedText { // Fallback for providers that don't provide edit_preview (like Copilot/Supermaven) // Just show the raw edit text with basic styling let mut text = String::new(); @@ -24793,7 +24788,7 @@ impl Focusable for BreakpointPromptEditor { } fn all_edits_insertions_or_deletions( - edits: &Vec<(Range, String)>, + edits: &Vec<(Range, Arc)>, snapshot: &MultiBufferSnapshot, ) -> bool { let mut all_insertions = true; diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 3709709c71fd1355014fce3f48681c632df7e18d..f411bda96ee444c1e4b2ff0c2d11f1475e85b9ac 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ 
-22915,7 +22915,7 @@ async fn assert_highlighted_edits( let text_anchor_edits = edits .clone() .into_iter() - .map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit)) + .map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit.into())) .collect::>(); let edit_preview = window diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 69e6b0a553cdb8c7ec90f1f19099f7cbc2a03e97..4dd90c15d9387327a75ece2d82385e406e5840d6 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -720,7 +720,7 @@ impl EditPreview { pub fn highlight_edits( &self, current_snapshot: &BufferSnapshot, - edits: &[(Range, String)], + edits: &[(Range, impl AsRef)], include_deletions: bool, cx: &App, ) -> HighlightedText { @@ -747,7 +747,8 @@ impl EditPreview { .end .bias_right(&self.old_snapshot) .to_offset(&self.applied_edits_snapshot); - let edit_start_in_preview_snapshot = edit_new_end_in_preview_snapshot - edit_text.len(); + let edit_start_in_preview_snapshot = + edit_new_end_in_preview_snapshot - edit_text.as_ref().len(); let unchanged_range_in_preview_snapshot = offset_in_preview_snapshot..edit_start_in_preview_snapshot; @@ -772,7 +773,7 @@ impl EditPreview { ); } - if !edit_text.is_empty() { + if !edit_text.as_ref().is_empty() { highlighted_text.add_text_from_buffer_range( edit_start_in_preview_snapshot..edit_new_end_in_preview_snapshot, &self.applied_edits_snapshot, @@ -796,7 +797,7 @@ impl EditPreview { highlighted_text.build() } - fn compute_visible_range(&self, edits: &[(Range, String)]) -> Option> { + fn compute_visible_range(&self, edits: &[(Range, T)]) -> Option> { let (first, _) = edits.first()?; let (last, _) = edits.last()?; @@ -1130,7 +1131,7 @@ impl Buffer { pub fn preview_edits( &self, - edits: Arc<[(Range, String)]>, + edits: Arc<[(Range, Arc)]>, cx: &App, ) -> Task { let registry = self.language_registry(); diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs 
index f0267ebd99b3b1bf806058f98453714daed93ef5..ec584abf4876b38da38459ccae900631957258d1 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -3120,15 +3120,13 @@ async fn test_preview_edits(cx: &mut TestAppContext) { .map(|(range, text)| { ( buffer.anchor_before(range.start)..buffer.anchor_after(range.end), - text.to_string(), + text.into(), ) }) - .collect::>() + .collect::>() }); let edit_preview = buffer - .read_with(cx, |buffer, cx| { - buffer.preview_edits(edits.clone().into(), cx) - }) + .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx)) .await; let highlighted_edits = cx.read(|cx| { edit_preview.highlight_edits(&buffer.read(cx).snapshot(), &edits, include_deletions, cx) diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 311fc7454ef9f11586a7ce0955b4e33d94e45c98..e1f58fe95a487f5be650d758df32b8097ee578e4 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -293,7 +293,7 @@ pub struct FunctionDefinition { pub parameters: Option, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage { Assistant { @@ -366,25 +366,42 @@ pub struct ImageUrl { pub detail: Option, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, #[serde(flatten)] pub content: ToolCallContent, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] pub enum ToolCallContent { Function { function: FunctionContent }, } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct FunctionContent { pub name: String, pub arguments: String, } +#[derive(Clone, 
Serialize, Deserialize, Debug)] +pub struct Response { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Choice { + pub index: u32, + pub message: RequestMessage, + pub finish_reason: Option, +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ResponseMessageDelta { pub role: Option, @@ -410,7 +427,7 @@ pub struct FunctionChunk { pub arguments: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u64, pub completion_tokens: u64, diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 32177aaa427e8616a2767410a7a6ec84c05abbee..0c9fe85da6130f5ea2040434a0dcd3727754d3c0 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -7,6 +7,7 @@ use language::{Anchor, Buffer, BufferSnapshot}; use std::{ ops::{AddAssign, Range}, path::Path, + sync::Arc, time::Duration, }; use text::{ToOffset, ToPoint}; @@ -51,7 +52,7 @@ fn completion_from_diff( ) -> EditPrediction { let buffer_text = snapshot.text_for_range(delete_range).collect::(); - let mut edits: Vec<(Range, String)> = Vec::new(); + let mut edits: Vec<(Range, Arc)> = Vec::new(); let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect(); let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect(); @@ -70,7 +71,10 @@ fn completion_from_diff( if k != 0 { let offset = snapshot.anchor_after(offset); // the range from the current position to item is an inlay. 
- let edit = (offset..offset, completion_graphemes[i..i + k].join("")); + let edit = ( + offset..offset, + completion_graphemes[i..i + k].join("").into(), + ); edits.push(edit); } i += k + 1; @@ -90,7 +94,7 @@ fn completion_from_diff( // there is leftover completion text, so drop it as an inlay. let edit_range = offset..offset; let edit_text = completion_graphemes[i..].join(""); - edits.push((edit_range, edit_text)); + edits.push((edit_range, edit_text.into())); } EditPrediction::Local { diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 454a1526a9e8c6a75d47bda875feb6843b454a0d..3bd614b480793c07c0c7b7e4f2578cd2b6cba6bd 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -133,7 +133,7 @@ pub struct EditPrediction { path: Arc, excerpt_range: Range, cursor_offset: usize, - edits: Arc<[(Range, String)]>, + edits: Arc<[(Range, Arc)]>, snapshot: BufferSnapshot, edit_preview: EditPreview, input_outline: Arc, @@ -150,7 +150,7 @@ impl EditPrediction { .duration_since(self.buffer_snapshotted_at) } - fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { + fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) } } @@ -711,7 +711,7 @@ impl Zeta { cx.spawn(async move |cx| { let output_excerpt: Arc = output_excerpt.into(); - let edits: Arc<[(Range, String)]> = cx + let edits: Arc<[(Range, Arc)]> = cx .background_spawn({ let output_excerpt = output_excerpt.clone(); let editable_range = editable_range.clone(); @@ -725,7 +725,7 @@ impl Zeta { let edits = edits.clone(); move |buffer, cx| { let new_snapshot = buffer.snapshot(); - let edits: Arc<[(Range, String)]> = + let edits: Arc<[(Range, Arc)]> = edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)? 
.into(); Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) @@ -759,7 +759,7 @@ impl Zeta { output_excerpt: Arc, editable_range: Range, snapshot: &BufferSnapshot, - ) -> Result, String)>> { + ) -> Result, Arc)>> { let content = output_excerpt.replace(CURSOR_MARKER, ""); let start_markers = content @@ -817,7 +817,7 @@ impl Zeta { new_text: &str, offset: usize, snapshot: &BufferSnapshot, - ) -> Vec<(Range, String)> { + ) -> Vec<(Range, Arc)> { text_diff(&old_text, new_text) .into_iter() .map(|(mut old_range, new_text)| { @@ -836,7 +836,7 @@ impl Zeta { ); old_range.end = old_range.end.saturating_sub(suffix_len); - let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); + let new_text = new_text[prefix_len..new_text.len() - suffix_len].into(); let range = if old_range.is_empty() { let anchor = snapshot.anchor_after(old_range.start); anchor..anchor @@ -1183,7 +1183,7 @@ impl CurrentEditPrediction { if old_edits.len() == 1 && new_edits.len() == 1 { let (old_range, old_text) = &old_edits[0]; let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text) + new_range == old_range && new_text.starts_with(old_text.as_ref()) } else { true } @@ -1599,13 +1599,8 @@ mod tests { #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); - let edits: Arc<[(Range, String)]> = cx.update(|cx| { - to_completion_edits( - [(2..5, "REM".to_string()), (9..11, "".to_string())], - &buffer, - cx, - ) - .into() + let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { + to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() }); let edit_preview = cx @@ -1635,7 +1630,7 @@ mod tests { &buffer, cx ), - vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + vec![(2..5, "REM".into()), (9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); @@ -1645,7 +1640,7 @@ 
mod tests { &buffer, cx ), - vec![(2..2, "REM".to_string()), (6..8, "".to_string())] + vec![(2..2, "REM".into()), (6..8, "".into())] ); buffer.update(cx, |buffer, cx| buffer.undo(cx)); @@ -1655,7 +1650,7 @@ mod tests { &buffer, cx ), - vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + vec![(2..5, "REM".into()), (9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); @@ -1665,7 +1660,7 @@ mod tests { &buffer, cx ), - vec![(3..3, "EM".to_string()), (7..9, "".to_string())] + vec![(3..3, "EM".into()), (7..9, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); @@ -1675,7 +1670,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] + vec![(4..4, "M".into()), (8..10, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); @@ -1685,7 +1680,7 @@ mod tests { &buffer, cx ), - vec![(9..11, "".to_string())] + vec![(9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); @@ -1695,7 +1690,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] + vec![(4..4, "M".into()), (8..10, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); @@ -1705,7 +1700,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string())] + vec![(4..4, "M".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); @@ -2211,10 +2206,10 @@ mod tests { } fn to_completion_edits( - iterator: impl IntoIterator, String)>, + iterator: impl IntoIterator, Arc)>, buffer: &Entity, cx: &App, - ) -> Vec<(Range, String)> { + ) -> Vec<(Range, Arc)> { let buffer = buffer.read(cx); iterator .into_iter() @@ -2228,10 +2223,10 @@ mod tests { } fn from_completion_edits( - editor_edits: &[(Range, String)], + editor_edits: &[(Range, Arc)], buffer: &Entity, cx: &App, - ) -> Vec<(Range, String)> { + ) -> Vec<(Range, Arc)> { let buffer = buffer.read(cx); editor_edits .iter() 
diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 13bb4e9106de9f5f201ba59106304a6aab4208d1..cbde212dd104bdc909dda19de403f815ff4f6386 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -28,9 +28,9 @@ indoc.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +open_ai.workspace = true project.workspace = true release_channel.workspace = true -schemars.workspace = true serde.workspace = true serde_json.workspace = true thiserror.workspace = true @@ -50,3 +50,4 @@ language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index a0dcd83b88142a5746c0b3c7d82bc7a64965edab..54a6987b3f781a48fe928636dc3537117ee6a401 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -1,17 +1,11 @@ -use std::{borrow::Cow, ops::Range, path::Path, sync::Arc}; - -use anyhow::Context as _; -use cloud_llm_client::predict_edits_v3; -use gpui::{App, AsyncApp, Entity}; -use language::{ - Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff, -}; -use project::Project; -use util::ResultExt; +use std::{ops::Range, sync::Arc}; + +use gpui::{AsyncApp, Entity}; +use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot}; use uuid::Uuid; #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(Uuid); +pub struct EditPredictionId(pub Uuid); impl Into for EditPredictionId { fn into(self) -> Uuid { @@ -34,8 +28,7 @@ impl std::fmt::Display for EditPredictionId { #[derive(Clone)] pub struct EditPrediction { pub id: EditPredictionId, - pub path: Arc, - pub edits: Arc<[(Range, String)]>, + pub edits: Arc<[(Range, Arc)]>, pub snapshot: 
BufferSnapshot, pub edit_preview: EditPreview, // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction. @@ -43,90 +36,43 @@ pub struct EditPrediction { } impl EditPrediction { - pub async fn from_response( - response: predict_edits_v3::PredictEditsResponse, - active_buffer_old_snapshot: &TextBufferSnapshot, - active_buffer: &Entity, - project: &Entity, + pub async fn new( + id: EditPredictionId, + edited_buffer: &Entity, + edited_buffer_snapshot: &BufferSnapshot, + edits: Vec<(Range, Arc)>, cx: &mut AsyncApp, ) -> Option { - // TODO only allow cloud to return one path - let Some(path) = response.edits.first().map(|e| e.path.clone()) else { - return None; - }; + let (edits, snapshot, edit_preview_task) = edited_buffer + .read_with(cx, |buffer, cx| { + let new_snapshot = buffer.snapshot(); + let edits: Arc<[_]> = + interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into(); - let is_same_path = active_buffer - .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx)) - .ok()?; - - let (buffer, edits, snapshot, edit_preview_task) = if is_same_path { - active_buffer - .read_with(cx, |buffer, cx| { - let new_snapshot = buffer.snapshot(); - let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot); - let edits: Arc<[_]> = - interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into(); - - Some(( - active_buffer.clone(), - edits.clone(), - new_snapshot, - buffer.preview_edits(edits, cx), - )) - }) - .ok()?? - } else { - let buffer_handle = project - .update(cx, |project, cx| { - let project_path = project - .find_project_path(&path, cx) - .context("Failed to find project path for zeta edit")?; - anyhow::Ok(project.open_buffer(project_path, cx)) - }) - .ok()? - .log_err()? 
- .await - .context("Failed to open buffer for zeta edit") - .log_err()?; - - buffer_handle - .read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(); - let edits = edits_from_response(&response.edits, &snapshot); - if edits.is_empty() { - return None; - } - Some(( - buffer_handle.clone(), - edits.clone(), - snapshot, - buffer.preview_edits(edits, cx), - )) - }) - .ok()?? - }; + Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) + }) + .ok()??; let edit_preview = edit_preview_task.await; Some(EditPrediction { - id: EditPredictionId(response.request_id), - path, + id, edits, snapshot, edit_preview, - buffer, + buffer: edited_buffer.clone(), }) } pub fn interpolate( &self, new_snapshot: &TextBufferSnapshot, - ) -> Option, String)>> { + ) -> Option, Arc)>> { interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone()) } - pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool { - buffer_path_eq(buffer, &self.path, cx) + pub fn targets_buffer(&self, buffer: &Buffer) -> bool { + self.snapshot.remote_id() == buffer.remote_id() } } @@ -134,21 +80,16 @@ impl std::fmt::Debug for EditPrediction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EditPrediction") .field("id", &self.id) - .field("path", &self.path) .field("edits", &self.edits) .finish() } } -pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool { - buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path) -} - pub fn interpolate_edits( old_snapshot: &TextBufferSnapshot, new_snapshot: &TextBufferSnapshot, - current_edits: Arc<[(Range, String)]>, -) -> Option, String)>> { + current_edits: Arc<[(Range, Arc)]>, +) -> Option, Arc)>> { let mut edits = Vec::new(); let mut model_edits = current_edits.iter().peekable(); @@ -173,7 +114,7 @@ pub fn interpolate_edits( if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { if !model_suffix.is_empty() { let anchor = 
old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.to_string())); + edits.push((anchor..anchor, model_suffix.into())); } model_edits.next(); @@ -190,135 +131,17 @@ pub fn interpolate_edits( if edits.is_empty() { None } else { Some(edits) } } -pub fn line_range_to_point_range(range: Range) -> Range { - language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0) -} - -fn edits_from_response( - edits: &[predict_edits_v3::Edit], - snapshot: &TextBufferSnapshot, -) -> Arc<[(Range, String)]> { - edits - .iter() - .flat_map(|edit| { - let point_range = line_range_to_point_range(edit.range.clone()); - let offset = point_range.to_offset(snapshot).start; - let old_text = snapshot.text_for_range(point_range); - - excerpt_edits_from_response( - old_text.collect::>(), - &edit.content, - offset, - &snapshot, - ) - }) - .collect::>() - .into() -} - -fn excerpt_edits_from_response( - old_text: Cow, - new_text: &str, - offset: usize, - snapshot: &TextBufferSnapshot, -) -> impl Iterator, String)> { - text_diff(&old_text, new_text) - .into_iter() - .map(move |(mut old_range, new_text)| { - old_range.start += offset; - old_range.end += offset; - - let prefix_len = common_prefix( - snapshot.chars_for_range(old_range.clone()), - new_text.chars(), - ); - old_range.start += prefix_len; - - let suffix_len = common_prefix( - snapshot.reversed_chars_for_range(old_range.clone()), - new_text[prefix_len..].chars().rev(), - ); - old_range.end = old_range.end.saturating_sub(suffix_len); - - let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); - let range = if old_range.is_empty() { - let anchor = snapshot.anchor_after(old_range.start); - anchor..anchor - } else { - snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end) - }; - (range, new_text) - }) -} - -fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { - a.zip(b) - .take_while(|(a, b)| a == b) - .map(|(a, _)| a.len_utf8()) - .sum() -} - 
#[cfg(test)] mod tests { - use std::path::PathBuf; - use super::*; - use cloud_llm_client::predict_edits_v3; - use edit_prediction_context::Line; use gpui::{App, Entity, TestAppContext, prelude::*}; - use indoc::indoc; use language::{Buffer, ToOffset as _}; - #[gpui::test] - async fn test_compute_edits(cx: &mut TestAppContext) { - let old = indoc! {r#" - fn main() { - let args = - println!("{}", args[1]) - } - "#}; - - let new = indoc! {r#" - fn main() { - let args = std::env::args(); - println!("{}", args[1]); - } - "#}; - - let buffer = cx.new(|cx| Buffer::local(old, cx)); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - - // TODO cover more cases when multi-file is supported - let big_edits = vec![predict_edits_v3::Edit { - path: PathBuf::from("test.txt").into(), - range: Line(0)..Line(old.lines().count() as u32), - content: new.into(), - }]; - - let edits = edits_from_response(&big_edits, &snapshot); - assert_eq!(edits.len(), 2); - assert_eq!( - edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 14) - ); - assert_eq!(edits[0].1, " std::env::args();"); - assert_eq!( - edits[1].0.to_point(&snapshot).start, - language::Point::new(2, 27) - ); - assert_eq!(edits[1].1, ";"); - } - #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); - let edits: Arc<[(Range, String)]> = cx.update(|cx| { - to_prediction_edits( - [(2..5, "REM".to_string()), (9..11, "".to_string())], - &buffer, - cx, - ) - .into() + let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { + to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() }); let edit_preview = cx @@ -329,7 +152,6 @@ mod tests { id: EditPredictionId(Uuid::new_v4()), edits, snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - path: Path::new("test.txt").into(), buffer: buffer.clone(), edit_preview, }; @@ -341,7 +163,7 @@ mod tests { &buffer, cx ), - vec![(2..5, 
"REM".to_string()), (9..11, "".to_string())] + vec![(2..5, "REM".into()), (9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); @@ -351,7 +173,7 @@ mod tests { &buffer, cx ), - vec![(2..2, "REM".to_string()), (6..8, "".to_string())] + vec![(2..2, "REM".into()), (6..8, "".into())] ); buffer.update(cx, |buffer, cx| buffer.undo(cx)); @@ -361,7 +183,7 @@ mod tests { &buffer, cx ), - vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + vec![(2..5, "REM".into()), (9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); @@ -371,7 +193,7 @@ mod tests { &buffer, cx ), - vec![(3..3, "EM".to_string()), (7..9, "".to_string())] + vec![(3..3, "EM".into()), (7..9, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); @@ -381,7 +203,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] + vec![(4..4, "M".into()), (8..10, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); @@ -391,7 +213,7 @@ mod tests { &buffer, cx ), - vec![(9..11, "".to_string())] + vec![(9..11, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); @@ -401,7 +223,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] + vec![(4..4, "M".into()), (8..10, "".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); @@ -411,7 +233,7 @@ mod tests { &buffer, cx ), - vec![(4..4, "M".to_string())] + vec![(4..4, "M".into())] ); buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); @@ -420,10 +242,10 @@ mod tests { } fn to_prediction_edits( - iterator: impl IntoIterator, String)>, + iterator: impl IntoIterator, Arc)>, buffer: &Entity, cx: &App, - ) -> Vec<(Range, String)> { + ) -> Vec<(Range, Arc)> { let buffer = buffer.read(cx); iterator .into_iter() @@ -437,10 +259,10 @@ mod tests { } fn from_prediction_edits( - editor_edits: &[(Range, String)], 
+ editor_edits: &[(Range, Arc)], buffer: &Entity, cx: &App, - ) -> Vec<(Range, String)> { + ) -> Vec<(Range, Arc)> { let buffer = buffer.read(cx); editor_edits .iter() diff --git a/crates/zeta2/src/related_excerpts.rs b/crates/zeta2/src/related_excerpts.rs deleted file mode 100644 index f1721020d000ec9b7ec308eaa3bac4951c45c3f8..0000000000000000000000000000000000000000 --- a/crates/zeta2/src/related_excerpts.rs +++ /dev/null @@ -1,717 +0,0 @@ -use std::{ - cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant, -}; - -use crate::{ - ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, - ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts, -}; -use anyhow::{Result, anyhow}; -use cloud_zeta2_prompt::write_codeblock; -use collections::HashMap; -use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line}; -use futures::{ - StreamExt, - channel::mpsc::{self, UnboundedSender}, - stream::BoxStream, -}; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; -use indoc::indoc; -use language::{ - Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _, -}; -use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolUse, MessageContent, Role, -}; -use project::{ - Project, WorktreeSettings, - search::{SearchQuery, SearchResult}, -}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use util::{ - ResultExt as _, - paths::{PathMatcher, PathStyle}, -}; -use workspace::item::Settings as _; - -const SEARCH_PROMPT: &str = indoc! {r#" - ## Task - - You are part of an edit prediction system in a code editor. 
Your role is to identify relevant code locations - that will serve as context for predicting the next required edit. - - **Your task:** - - Analyze the user's recent edits and current cursor context - - Use the `search` tool to find code that may be relevant for predicting the next edit - - Focus on finding: - - Code patterns that might need similar changes based on the recent edits - - Functions, variables, types, and constants referenced in the current cursor context - - Related implementations, usages, or dependencies that may require consistent updates - - **Important constraints:** - - This conversation has exactly 2 turns - - You must make ALL search queries in your first response via the `search` tool - - All queries will be executed in parallel and results returned together - - In the second turn, you will select the most relevant results via the `select` tool. - - ## User Edits - - {edits} - - ## Current cursor context - - `````{current_file_path} - {cursor_excerpt} - ````` - - -- - Use the `search` tool now -"#}; - -const SEARCH_TOOL_NAME: &str = "search"; - -/// Search for relevant code -/// -/// For the best results, run multiple queries at once with a single invocation of this tool. -#[derive(Clone, Deserialize, Serialize, JsonSchema)] -pub struct SearchToolInput { - /// An array of queries to run for gathering context relevant to the next prediction - #[schemars(length(max = 5))] - pub queries: Box<[SearchToolQuery]>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] -pub struct SearchToolQuery { - /// A glob pattern to match file paths in the codebase - pub glob: String, - /// A regular expression to match content within the files matched by the glob pattern - pub regex: String, -} - -const RESULTS_MESSAGE: &str = indoc! {" - Here are the results of your queries combined and grouped by file: - -"}; - -const SELECT_TOOL_NAME: &str = "select"; - -const SELECT_PROMPT: &str = indoc! 
{" - Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message. - Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits. - Include up to 200 lines in total. -"}; - -/// Select line ranges from search results -#[derive(Deserialize, JsonSchema)] -struct SelectToolInput { - /// The line ranges to select from search results. - ranges: Vec, -} - -/// A specific line range to select from a file -#[derive(Debug, Deserialize, JsonSchema)] -struct SelectLineRange { - /// The file path containing the lines to select - /// Exactly as it appears in the search result codeblocks. - path: PathBuf, - /// The starting line number (1-based) - #[schemars(range(min = 1))] - start_line: u32, - /// The ending line number (1-based, inclusive) - #[schemars(range(min = 1))] - end_line: u32, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct LlmContextOptions { - pub excerpt: EditPredictionExcerptOptions, -} - -pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; - -pub fn find_related_excerpts( - buffer: Entity, - cursor_position: Anchor, - project: &Entity, - mut edit_history_unified_diff: String, - options: &LlmContextOptions, - debug_tx: Option>, - cx: &App, -) -> Task, Vec>>>> { - let language_model_registry = LanguageModelRegistry::global(cx); - let Some(model) = language_model_registry - .read(cx) - .available_models(cx) - .find(|model| { - model.provider_id() == MODEL_PROVIDER_ID - && model.id() == LanguageModelId("claude-haiku-4-5-latest".into()) - // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b") - // model.provider_id() == LanguageModelProviderId::new("ollama") - // && model.id() == LanguageModelId("gpt-oss:20b".into()) - }) - else { - return Task::ready(Err(anyhow!("could not find context model"))); - }; - - if edit_history_unified_diff.is_empty() { - edit_history_unified_diff.push_str("(No user 
edits yet)"); - } - - // TODO [zeta2] include breadcrumbs? - let snapshot = buffer.read(cx).snapshot(); - let cursor_point = cursor_position.to_point(&snapshot); - let Some(cursor_excerpt) = - EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None) - else { - return Task::ready(Ok(HashMap::default())); - }; - - let current_file_path = snapshot - .file() - .map(|f| f.full_path(cx).display().to_string()) - .unwrap_or_else(|| "untitled".to_string()); - - let prompt = SEARCH_PROMPT - .replace("{edits}", &edit_history_unified_diff) - .replace("{current_file_path}", ¤t_file_path) - .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body); - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( - ZetaContextRetrievalStartedDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - search_prompt: prompt.clone(), - }, - )) - .ok(); - } - - let path_style = project.read(cx).path_style(cx); - - let exclude_matcher = { - let global_settings = WorktreeSettings::get_global(cx); - let exclude_patterns = global_settings - .file_scan_exclusions - .sources() - .iter() - .chain(global_settings.private_files.sources().iter()); - - match PathMatcher::new(exclude_patterns, path_style) { - Ok(matcher) => matcher, - Err(err) => { - return Task::ready(Err(anyhow!(err))); - } - } - }; - - let project = project.clone(); - cx.spawn(async move |cx| { - let initial_prompt_message = LanguageModelRequestMessage { - role: Role::User, - content: vec![prompt.into()], - cache: false, - }; - - let mut search_stream = request_tool_call::( - vec![initial_prompt_message.clone()], - SEARCH_TOOL_NAME, - &model, - cx, - ) - .await?; - - let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt - select_request_messages.push(initial_prompt_message); - - let mut regex_by_glob: HashMap = HashMap::default(); - let mut 
search_calls = Vec::new(); - - while let Some(event) = search_stream.next().await { - match event? { - LanguageModelCompletionEvent::ToolUse(tool_use) => { - if !tool_use.is_input_complete { - continue; - } - - if tool_use.name.as_ref() == SEARCH_TOOL_NAME { - let input = - serde_json::from_value::(tool_use.input.clone())?; - - for query in input.queries { - let regex = regex_by_glob.entry(query.glob).or_default(); - if !regex.is_empty() { - regex.push('|'); - } - regex.push_str(&query.regex); - } - - search_calls.push(tool_use); - } else { - log::warn!( - "context gathering model tried to use unknown tool: {}", - tool_use.name - ); - } - } - LanguageModelCompletionEvent::Text(txt) => { - if let Some(LanguageModelRequestMessage { - role: Role::Assistant, - content, - .. - }) = select_request_messages.last_mut() - { - if let Some(MessageContent::Text(existing_text)) = content.last_mut() { - existing_text.push_str(&txt); - } else { - content.push(MessageContent::Text(txt)); - } - } else { - select_request_messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::Text(txt)], - cache: false, - }); - } - } - LanguageModelCompletionEvent::Thinking { text, signature } => { - if let Some(LanguageModelRequestMessage { - role: Role::Assistant, - content, - .. - }) = select_request_messages.last_mut() - { - if let Some(MessageContent::Thinking { - text: existing_text, - signature: existing_signature, - }) = content.last_mut() - { - existing_text.push_str(&text); - *existing_signature = signature; - } else { - content.push(MessageContent::Thinking { text, signature }); - } - } else { - select_request_messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::Thinking { text, signature }], - cache: false, - }); - } - } - LanguageModelCompletionEvent::RedactedThinking { data } => { - if let Some(LanguageModelRequestMessage { - role: Role::Assistant, - content, - .. 
- }) = select_request_messages.last_mut() - { - if let Some(MessageContent::RedactedThinking(existing_data)) = - content.last_mut() - { - existing_data.push_str(&data); - } else { - content.push(MessageContent::RedactedThinking(data)); - } - } else { - select_request_messages.push(LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::RedactedThinking(data)], - cache: false, - }); - } - } - ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => { - log::error!("{ev:?}"); - } - ev => { - log::trace!("context search event: {ev:?}") - } - } - } - - let search_tool_use = if search_calls.is_empty() { - log::warn!("context model ran 0 searches"); - return anyhow::Ok(Default::default()); - } else if search_calls.len() == 1 { - search_calls.swap_remove(0) - } else { - // In theory, the model could perform multiple search calls - // Dealing with them separately is not worth it when it doesn't happen in practice. - // If it were to happen, here we would combine them into one. 
- // The second request doesn't need to know it was actually two different calls ;) - let input = serde_json::to_value(&SearchToolInput { - queries: regex_by_glob - .iter() - .map(|(glob, regex)| SearchToolQuery { - glob: glob.clone(), - regex: regex.clone(), - }) - .collect(), - }) - .unwrap_or_default(); - - LanguageModelToolUse { - id: search_calls.swap_remove(0).id, - name: SELECT_TOOL_NAME.into(), - raw_input: serde_json::to_string(&input).unwrap_or_default(), - input, - is_input_complete: true, - } - }; - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( - ZetaSearchQueryDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - queries: regex_by_glob - .iter() - .map(|(glob, regex)| SearchToolQuery { - glob: glob.clone(), - regex: regex.clone(), - }) - .collect(), - }, - )) - .ok(); - } - - let (results_tx, mut results_rx) = mpsc::unbounded(); - - for (glob, regex) in regex_by_glob { - let exclude_matcher = exclude_matcher.clone(); - let results_tx = results_tx.clone(); - let project = project.clone(); - cx.spawn(async move |cx| { - run_query( - &glob, - ®ex, - results_tx.clone(), - path_style, - exclude_matcher, - &project, - cx, - ) - .await - .log_err(); - }) - .detach() - } - drop(results_tx); - - struct ResultBuffer { - buffer: Entity, - snapshot: TextBufferSnapshot, - } - - let (result_buffers_by_path, merged_result) = cx - .background_spawn(async move { - let mut excerpts_by_buffer: HashMap, MatchedBuffer> = - HashMap::default(); - - while let Some((buffer, matched)) = results_rx.next().await { - match excerpts_by_buffer.entry(buffer) { - Entry::Occupied(mut entry) => { - let entry = entry.get_mut(); - entry.full_path = matched.full_path; - entry.snapshot = matched.snapshot; - entry.line_ranges.extend(matched.line_ranges); - } - Entry::Vacant(entry) => { - entry.insert(matched); - } - } - } - - let mut result_buffers_by_path = HashMap::default(); - let mut merged_result = 
RESULTS_MESSAGE.to_string(); - - for (buffer, mut matched) in excerpts_by_buffer { - matched - .line_ranges - .sort_unstable_by_key(|range| (range.start, Reverse(range.end))); - - write_codeblock( - &matched.full_path, - merge_excerpts(&matched.snapshot, matched.line_ranges).iter(), - &[], - Line(matched.snapshot.max_point().row), - true, - &mut merged_result, - ); - - result_buffers_by_path.insert( - matched.full_path, - ResultBuffer { - buffer, - snapshot: matched.snapshot.text, - }, - ); - } - - (result_buffers_by_path, merged_result) - }) - .await; - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } - - let tool_result = LanguageModelToolResult { - tool_use_id: search_tool_use.id.clone(), - tool_name: SEARCH_TOOL_NAME.into(), - is_error: false, - content: merged_result.into(), - output: None, - }; - - select_request_messages.extend([ - LanguageModelRequestMessage { - role: Role::Assistant, - content: vec![MessageContent::ToolUse(search_tool_use)], - cache: false, - }, - LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::ToolResult(tool_result)], - cache: false, - }, - ]); - - if result_buffers_by_path.is_empty() { - log::trace!("context gathering queries produced no results"); - return anyhow::Ok(HashMap::default()); - } - - select_request_messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![SELECT_PROMPT.into()], - cache: false, - }); - - let mut select_stream = request_tool_call::( - select_request_messages, - SELECT_TOOL_NAME, - &model, - cx, - ) - .await?; - - cx.background_spawn(async move { - let mut selected_ranges = Vec::new(); - - while let Some(event) = select_stream.next().await { - match event? 
{ - LanguageModelCompletionEvent::ToolUse(tool_use) => { - if !tool_use.is_input_complete { - continue; - } - - if tool_use.name.as_ref() == SELECT_TOOL_NAME { - let call = - serde_json::from_value::(tool_use.input.clone())?; - selected_ranges.extend(call.ranges); - } else { - log::warn!( - "context gathering model tried to use unknown tool: {}", - tool_use.name - ); - } - } - ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => { - log::error!("{ev:?}"); - } - ev => { - log::trace!("context select event: {ev:?}") - } - } - } - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchResultsFiltered( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } - - if selected_ranges.is_empty() { - log::trace!("context gathering selected no ranges") - } - - selected_ranges.sort_unstable_by(|a, b| { - a.start_line - .cmp(&b.start_line) - .then(b.end_line.cmp(&a.end_line)) - }); - - let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default(); - - for selected_range in selected_ranges { - if let Some(ResultBuffer { buffer, snapshot }) = - result_buffers_by_path.get(&selected_range.path) - { - let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0); - let end_point = - snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left); - let range = - snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point); - - related_excerpts_by_buffer - .entry(buffer.clone()) - .or_default() - .push(range); - } else { - log::warn!( - "selected path that wasn't included in search results: {}", - selected_range.path.display() - ); - } - } - - anyhow::Ok(related_excerpts_by_buffer) - }) - .await - }) -} - -async fn request_tool_call( - messages: Vec, - tool_name: &'static str, - model: &Arc, - cx: &mut AsyncApp, -) -> Result>> -{ - let schema = schemars::schema_for!(T); - - let request = LanguageModelRequest { - messages, - tools: 
vec![LanguageModelRequestTool { - name: tool_name.into(), - description: schema - .get("description") - .and_then(|description| description.as_str()) - .unwrap() - .to_string(), - input_schema: serde_json::to_value(schema).unwrap(), - }], - ..Default::default() - }; - - Ok(model.stream_completion(request, cx).await?) -} - -const MIN_EXCERPT_LEN: usize = 16; -const MAX_EXCERPT_LEN: usize = 768; -const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5; - -struct MatchedBuffer { - snapshot: BufferSnapshot, - line_ranges: Vec>, - full_path: PathBuf, -} - -async fn run_query( - glob: &str, - regex: &str, - results_tx: UnboundedSender<(Entity, MatchedBuffer)>, - path_style: PathStyle, - exclude_matcher: PathMatcher, - project: &Entity, - cx: &mut AsyncApp, -) -> Result<()> { - let include_matcher = PathMatcher::new(vec![glob], path_style)?; - - let query = SearchQuery::regex( - regex, - false, - true, - false, - true, - include_matcher, - exclude_matcher, - true, - None, - )?; - - let results = project.update(cx, |project, cx| project.search(query, cx))?; - futures::pin_mut!(results); - - let mut total_bytes = 0; - - while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { - if ranges.is_empty() { - continue; - } - - let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| { - Some((buffer.snapshot(), buffer.file()?.full_path(cx))) - })? 
- else { - continue; - }; - - let results_tx = results_tx.clone(); - cx.background_spawn(async move { - let mut line_ranges = Vec::with_capacity(ranges.len()); - - for range in ranges { - let offset_range = range.to_offset(&snapshot); - let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot); - - if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY { - break; - } - - let excerpt = EditPredictionExcerpt::select_from_buffer( - query_point, - &snapshot, - &EditPredictionExcerptOptions { - max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes), - min_bytes: MIN_EXCERPT_LEN, - target_before_cursor_over_total_bytes: 0.5, - }, - None, - ); - - if let Some(excerpt) = excerpt { - total_bytes += excerpt.range.len(); - if !excerpt.line_range.is_empty() { - line_ranges.push(excerpt.line_range); - } - } - } - - results_tx - .unbounded_send(( - buffer, - MatchedBuffer { - snapshot, - line_ranges, - full_path, - }, - )) - .log_err(); - }) - .detach(); - } - - anyhow::Ok(()) -} diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs new file mode 100644 index 0000000000000000000000000000000000000000..e2e78c3e3b295549ca2c294818f935f1d7d8a9f9 --- /dev/null +++ b/crates/zeta2/src/retrieval_search.rs @@ -0,0 +1,194 @@ +use std::ops::Range; + +use anyhow::Result; +use collections::HashMap; +use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +use futures::{ + StreamExt, + channel::mpsc::{self, UnboundedSender}, +}; +use gpui::{AppContext, AsyncApp, Entity}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _}; +use project::{ + Project, WorktreeSettings, + search::{SearchQuery, SearchResult}, +}; +use util::{ + ResultExt as _, + paths::{PathMatcher, PathStyle}, +}; +use workspace::item::Settings as _; + +pub async fn run_retrieval_searches( + project: Entity, + regex_by_glob: HashMap, + cx: &mut AsyncApp, +) -> Result, Vec>>> { + let 
(exclude_matcher, path_style) = project.update(cx, |project, cx| { + let global_settings = WorktreeSettings::get_global(cx); + let exclude_patterns = global_settings + .file_scan_exclusions + .sources() + .iter() + .chain(global_settings.private_files.sources().iter()); + let path_style = project.path_style(cx); + anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style)) + })??; + + let (results_tx, mut results_rx) = mpsc::unbounded(); + + for (glob, regex) in regex_by_glob { + let exclude_matcher = exclude_matcher.clone(); + let results_tx = results_tx.clone(); + let project = project.clone(); + cx.spawn(async move |cx| { + run_query( + &glob, + &regex, + results_tx.clone(), + path_style, + exclude_matcher, + &project, + cx, + ) + .await + .log_err(); + }) + .detach() + } + drop(results_tx); + + cx.background_spawn(async move { + let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default(); + let mut snapshots = HashMap::default(); + + let mut total_bytes = 0; + 'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await { + snapshots.insert(buffer.entity_id(), snapshot); + let existing = results.entry(buffer).or_default(); + existing.reserve(excerpts.len()); + + for (range, size) in excerpts { + // Blunt trimming of the results until we have a proper algorithmic filtering step + if (total_bytes + size) > MAX_RESULTS_LEN { + log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B"); + break 'outer; + } + total_bytes += size; + existing.push(range); + } + } + + for (buffer, ranges) in results.iter_mut() { + if let Some(snapshot) = snapshots.get(&buffer.entity_id()) { + ranges.sort_unstable_by(|a, b| { + a.start + .cmp(&b.start, snapshot) + .then(b.end.cmp(&a.end, snapshot)) + }); + + let mut index = 1; + while index < ranges.len() { + if ranges[index - 1] + .end + .cmp(&ranges[index].start, snapshot) + .is_gt() + { + let removed = ranges.remove(index); + ranges[index - 1].end = removed.end; + } else { + index += 1; + } + } + } + }
+ + Ok(results) + }) + .await +} + +const MIN_EXCERPT_LEN: usize = 16; +const MAX_EXCERPT_LEN: usize = 768; +const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; + +async fn run_query( + glob: &str, + regex: &str, + results_tx: UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + path_style: PathStyle, + exclude_matcher: PathMatcher, + project: &Entity, + cx: &mut AsyncApp, +) -> Result<()> { + let include_matcher = PathMatcher::new(vec![glob], path_style)?; + + let query = SearchQuery::regex( + regex, + false, + true, + false, + true, + include_matcher, + exclude_matcher, + true, + None, + )?; + + let results = project.update(cx, |project, cx| project.search(query, cx))?; + futures::pin_mut!(results); + + while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { + if results_tx.is_closed() { + break; + } + + if ranges.is_empty() { + continue; + } + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let results_tx = results_tx.clone(); + + cx.background_spawn(async move { + let mut excerpts = Vec::with_capacity(ranges.len()); + + for range in ranges { + let offset_range = range.to_offset(&snapshot); + let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot); + + let excerpt = EditPredictionExcerpt::select_from_buffer( + query_point, + &snapshot, + &EditPredictionExcerptOptions { + max_bytes: MAX_EXCERPT_LEN, + min_bytes: MIN_EXCERPT_LEN, + target_before_cursor_over_total_bytes: 0.5, + }, + None, + ); + + if let Some(excerpt) = excerpt + && !excerpt.line_range.is_empty() + { + excerpts.push(( + snapshot.anchor_after(excerpt.range.start) + ..snapshot.anchor_before(excerpt.range.end), + excerpt.range.len(), + )); + } + } + + let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts)); + + if let Err(err) = send_result + && !err.is_disconnected() + { + log::error!("{err}"); + } + }) + .detach(); + } + + anyhow::Ok(()) +} diff --git a/crates/zeta2/src/udiff.rs 
b/crates/zeta2/src/udiff.rs new file mode 100644 index 0000000000000000000000000000000000000000..866ab6f7cead61ba5add462404bd594080e3098e --- /dev/null +++ b/crates/zeta2/src/udiff.rs @@ -0,0 +1,1024 @@ +use std::borrow::Cow; +use std::fmt::Display; +use std::sync::Arc; +use std::{ + fmt::{Debug, Write}, + mem, + ops::Range, + path::Path, +}; + +use anyhow::Context as _; +use anyhow::Result; +use anyhow::anyhow; +use collections::HashMap; +use gpui::AsyncApp; +use gpui::Entity; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot}; +use project::Project; + +pub async fn parse_diff<'a>( + diff: &'a str, + get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, +) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { + let mut diff = DiffParser::new(diff); + let mut edited_buffer = None; + let mut edits = Vec::new(); + + while let Some(event) = diff.next()? { + match event { + DiffEvent::Hunk { + path: file_path, + hunk, + } => { + let (buffer, ranges) = match edited_buffer { + None => { + edited_buffer = get_buffer(&Path::new(file_path.as_ref())); + edited_buffer + .as_ref() + .context("Model tried to edit a file that wasn't included")? 
+ } + Some(ref current) => current, + }; + + edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?); + } + DiffEvent::FileEnd { renamed_to } => { + let (buffer, _) = edited_buffer + .take() + .expect("Got a FileEnd event before an Hunk event"); + + if renamed_to.is_some() { + anyhow::bail!("edit predictions cannot rename files"); + } + + if diff.next()?.is_some() { + anyhow::bail!("Edited more than one file"); + } + + return Ok((buffer, edits)); + } + } + } + + Err(anyhow::anyhow!("No EOF")) +} + +#[derive(Debug)] +pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap, Entity>); + +#[must_use] +pub async fn apply_diff<'a>( + diff: &'a str, + project: &Entity, + cx: &mut AsyncApp, +) -> Result> { + let mut included_files = HashMap::default(); + + for line in diff.lines() { + let diff_line = DiffLine::parse(line); + + if let DiffLine::OldPath { path } = diff_line { + let buffer = project + .update(cx, |project, cx| { + let project_path = + project + .find_project_path(path.as_ref(), cx) + .with_context(|| { + format!("Failed to find worktree for new path: {}", path) + })?; + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? + .await?; + + included_files.insert(path, buffer); + } + } + + let ranges = [Anchor::MIN..Anchor::MAX]; + + let mut diff = DiffParser::new(diff); + let mut current_file = None; + let mut edits = vec![]; + + while let Some(event) = diff.next()? 
{ + match event { + DiffEvent::Hunk { + path: file_path, + hunk, + } => { + let (buffer, ranges) = match current_file { + None => { + let buffer = included_files + .get_mut(&file_path) + .expect("Opened all files in diff"); + + current_file = Some((buffer, ranges.as_slice())); + current_file.as_ref().unwrap() + } + Some(ref current) => current, + }; + + buffer.read_with(cx, |buffer, _| { + edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?); + anyhow::Ok(()) + })??; + } + DiffEvent::FileEnd { renamed_to } => { + let (buffer, _) = current_file + .take() + .expect("Got a FileEnd event before an Hunk event"); + + if let Some(renamed_to) = renamed_to { + project + .update(cx, |project, cx| { + let new_project_path = project + .find_project_path(Path::new(renamed_to.as_ref()), cx) + .with_context(|| { + format!("Failed to find worktree for new path: {}", renamed_to) + })?; + + let project_file = project::File::from_dyn(buffer.read(cx).file()) + .expect("Wrong file type"); + + anyhow::Ok(project.rename_entry( + project_file.entry_id.unwrap(), + new_project_path, + cx, + )) + })?? 
+ .await?; + } + + let edits = mem::take(&mut edits); + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + })?; + } + } + } + + Ok(OpenedBuffers(included_files)) +} + +struct PatchFile<'a> { + old_path: Cow<'a, str>, + new_path: Cow<'a, str>, +} + +struct DiffParser<'a> { + current_file: Option>, + current_line: Option<(&'a str, DiffLine<'a>)>, + hunk: Hunk, + diff: std::str::Lines<'a>, +} + +#[derive(Debug, PartialEq)] +enum DiffEvent<'a> { + Hunk { path: Cow<'a, str>, hunk: Hunk }, + FileEnd { renamed_to: Option> }, +} + +#[derive(Debug, Default, PartialEq)] +struct Hunk { + context: String, + edits: Vec, +} + +impl Hunk { + fn is_empty(&self) -> bool { + self.context.is_empty() && self.edits.is_empty() + } +} + +#[derive(Debug, PartialEq)] +struct Edit { + range: Range, + text: String, +} + +impl<'a> DiffParser<'a> { + fn new(diff: &'a str) -> Self { + let mut diff = diff.lines(); + let current_line = diff.next().map(|line| (line, DiffLine::parse(line))); + DiffParser { + current_file: None, + hunk: Hunk::default(), + current_line, + diff, + } + } + + fn next(&mut self) -> Result>> { + loop { + let (hunk_done, file_done) = match self.current_line.as_ref().map(|e| &e.1) { + Some(DiffLine::OldPath { .. 
}) | Some(DiffLine::Garbage(_)) | None => (true, true), + Some(DiffLine::HunkHeader(_)) => (true, false), + _ => (false, false), + }; + + if hunk_done { + if let Some(file) = &self.current_file + && !self.hunk.is_empty() + { + return Ok(Some(DiffEvent::Hunk { + path: file.old_path.clone(), + hunk: mem::take(&mut self.hunk), + })); + } + } + + if file_done { + if let Some(PatchFile { old_path, new_path }) = self.current_file.take() { + return Ok(Some(DiffEvent::FileEnd { + renamed_to: if old_path != new_path { + Some(new_path) + } else { + None + }, + })); + } + } + + let Some((line, parsed_line)) = self.current_line.take() else { + break; + }; + + util::maybe!({ + match parsed_line { + DiffLine::OldPath { path } => { + self.current_file = Some(PatchFile { + old_path: path, + new_path: "".into(), + }); + } + DiffLine::NewPath { path } => { + if let Some(current_file) = &mut self.current_file { + current_file.new_path = path + } + } + DiffLine::HunkHeader(_) => {} + DiffLine::Context(ctx) => { + if self.current_file.is_some() { + writeln!(&mut self.hunk.context, "{ctx}")?; + } + } + DiffLine::Deletion(del) => { + if self.current_file.is_some() { + let range = self.hunk.context.len() + ..self.hunk.context.len() + del.len() + '\n'.len_utf8(); + if let Some(last_edit) = self.hunk.edits.last_mut() + && last_edit.range.end == range.start + { + last_edit.range.end = range.end; + } else { + self.hunk.edits.push(Edit { + range, + text: String::new(), + }); + } + writeln!(&mut self.hunk.context, "{del}")?; + } + } + DiffLine::Addition(add) => { + if self.current_file.is_some() { + let range = self.hunk.context.len()..self.hunk.context.len(); + if let Some(last_edit) = self.hunk.edits.last_mut() + && last_edit.range.end == range.start + { + writeln!(&mut last_edit.text, "{add}").unwrap(); + } else { + self.hunk.edits.push(Edit { + range, + text: format!("{add}\n"), + }); + } + } + } + DiffLine::Garbage(_) => {} + } + + anyhow::Ok(()) + }) + .with_context(|| format!("on 
line:\n\n```\n{}```", line))?; + + self.current_line = self.diff.next().map(|line| (line, DiffLine::parse(line))); + } + + anyhow::Ok(None) + } +} + +fn resolve_hunk_edits_in_buffer( + hunk: Hunk, + buffer: &TextBufferSnapshot, + ranges: &[Range], +) -> Result, Arc)>, anyhow::Error> { + let context_offset = if hunk.context.is_empty() { + Ok(0) + } else { + let mut offset = None; + for range in ranges { + let range = range.to_offset(buffer); + let text = buffer.text_for_range(range.clone()).collect::(); + for (ix, _) in text.match_indices(&hunk.context) { + if offset.is_some() { + anyhow::bail!("Context is not unique enough:\n{}", hunk.context); + } + offset = Some(range.start + ix); + } + } + offset.ok_or_else(|| { + anyhow!( + "Failed to match context:\n{}\n\nBuffer:\n{}", + hunk.context, + buffer.text(), + ) + }) + }?; + let iter = hunk.edits.into_iter().flat_map(move |edit| { + let old_text = buffer + .text_for_range(context_offset + edit.range.start..context_offset + edit.range.end) + .collect::(); + let edits_within_hunk = language::text_diff(&old_text, &edit.text); + edits_within_hunk + .into_iter() + .map(move |(inner_range, inner_text)| { + ( + buffer.anchor_after(context_offset + edit.range.start + inner_range.start) + ..buffer.anchor_before(context_offset + edit.range.start + inner_range.end), + inner_text, + ) + }) + }); + Ok(iter) +} + +#[derive(Debug, PartialEq)] +pub enum DiffLine<'a> { + OldPath { path: Cow<'a, str> }, + NewPath { path: Cow<'a, str> }, + HunkHeader(Option), + Context(&'a str), + Deletion(&'a str), + Addition(&'a str), + Garbage(&'a str), +} + +#[derive(Debug, PartialEq)] +pub struct HunkLocation { + start_line_old: u32, + count_old: u32, + start_line_new: u32, + count_new: u32, +} + +impl<'a> DiffLine<'a> { + pub fn parse(line: &'a str) -> Self { + Self::try_parse(line).unwrap_or(Self::Garbage(line)) + } + + fn try_parse(line: &'a str) -> Option { + if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) { + 
let path = parse_header_path("a/", header); + Some(Self::OldPath { path }) + } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) { + Some(Self::NewPath { + path: parse_header_path("b/", header), + }) + } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) { + if header.starts_with("...") { + return Some(Self::HunkHeader(None)); + } + + let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; + let mut parts = header.split_ascii_whitespace(); + let count_old = parts.next()?; + let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?; + + Some(Self::HunkHeader(Some(HunkLocation { + start_line_old: start_line_old.parse::().ok()?.saturating_sub(1), + count_old: count_old.parse().ok()?, + start_line_new: start_line_new.parse::().ok()?.saturating_sub(1), + count_new: count_new.parse().ok()?, + }))) + } else if let Some(deleted_header) = line.strip_prefix("-") { + Some(Self::Deletion(deleted_header)) + } else if line.is_empty() { + Some(Self::Context("")) + } else if let Some(context) = line.strip_prefix(" ") { + Some(Self::Context(context)) + } else { + Some(Self::Addition(line.strip_prefix("+")?)) + } + } +} + +impl<'a> Display for DiffLine<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DiffLine::OldPath { path } => write!(f, "--- {path}"), + DiffLine::NewPath { path } => write!(f, "+++ {path}"), + DiffLine::HunkHeader(Some(hunk_location)) => { + write!( + f, + "@@ -{},{} +{},{} @@", + hunk_location.start_line_old + 1, + hunk_location.count_old, + hunk_location.start_line_new + 1, + hunk_location.count_new + ) + } + DiffLine::HunkHeader(None) => write!(f, "@@ ... 
@@"), + DiffLine::Context(content) => write!(f, " {content}"), + DiffLine::Deletion(content) => write!(f, "-{content}"), + DiffLine::Addition(content) => write!(f, "+{content}"), + DiffLine::Garbage(line) => write!(f, "{line}"), + } + } +} + +fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { + if !header.contains(['"', '\\']) { + let path = header.split_ascii_whitespace().next().unwrap_or(header); + return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path)); + } + + let mut path = String::with_capacity(header.len()); + let mut in_quote = false; + let mut chars = header.chars().peekable(); + let mut strip_prefix = Some(strip_prefix); + + while let Some(char) = chars.next() { + if char == '"' { + in_quote = !in_quote; + } else if char == '\\' { + let Some(&next_char) = chars.peek() else { + break; + }; + chars.next(); + path.push(next_char); + } else if char.is_ascii_whitespace() && !in_quote { + break; + } else { + path.push(char); + } + + if let Some(prefix) = strip_prefix + && path == prefix + { + strip_prefix.take(); + path.clear(); + } + } + + Cow::Owned(path) +} + +fn eat_required_whitespace(header: &str) -> Option<&str> { + let trimmed = header.trim_ascii_start(); + + if trimmed.len() == header.len() { + None + } else { + Some(trimmed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use indoc::indoc; + use language::Point; + use pretty_assertions::assert_eq; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[test] + fn parse_lines_simple() { + let input = indoc! 
{" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/file.txt + +++ b/file.txt + @@ -1,2 +1,3 @@ + context + -deleted + +inserted + garbage + + --- b/file.txt + +++ a/file.txt + "}; + + let lines = input.lines().map(DiffLine::parse).collect::>(); + + pretty_assertions::assert_eq!( + lines, + &[ + DiffLine::Garbage("diff --git a/text.txt b/text.txt"), + DiffLine::Garbage("index 86c770d..a1fd855 100644"), + DiffLine::OldPath { + path: "file.txt".into() + }, + DiffLine::NewPath { + path: "file.txt".into() + }, + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + DiffLine::Context("context"), + DiffLine::Deletion("deleted"), + DiffLine::Addition("inserted"), + DiffLine::Garbage("garbage"), + DiffLine::Context(""), + DiffLine::OldPath { + path: "b/file.txt".into() + }, + DiffLine::NewPath { + path: "a/file.txt".into() + }, + ] + ); + } + + #[test] + fn file_header_extra_space() { + let options = ["--- file", "--- file", "---\tfile"]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::OldPath { + path: "file".into() + }, + "{option}", + ); + } + } + + #[test] + fn hunk_header_extra_space() { + let options = [ + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@\t-1,2\t+1,3\t@@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@ garbage", + ]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + "{option}", + ); + } + } + + #[test] + fn hunk_header_without_location() { + pretty_assertions::assert_eq!(DiffLine::parse("@@ ... 
@@"), DiffLine::HunkHeader(None)); + } + + #[test] + fn test_parse_path() { + assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + + // Extra + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt 2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt\t2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt \""), + "foo/bar/baz.txt" + ); + + // Quoted + assert_eq!( + parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷"); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"), + "foo/bar/baz quox.txt" + ); + // unescaped quotes are dropped + assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar"); + + // Escaped + assert_eq!( + parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""), + "foo/\"bar\"/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""), + "C:\\Projects\\My App\\old file.txt" + ); + } + + #[test] + fn test_parse_diff_with_leading_and_trailing_garbage() { + let diff = indoc! {" + I need to make some changes. + + I'll change the following things: + - one + - two + - three + + ``` + --- a/file.txt + +++ b/file.txt + one + +AND + two + ``` + + Summary of what I did: + - one + - two + - three + + That's about it. 
+ "}; + + let mut events = Vec::new(); + let mut parser = DiffParser::new(diff); + while let Some(event) = parser.next().unwrap() { + events.push(event); + } + + assert_eq!( + events, + &[ + DiffEvent::Hunk { + path: "file.txt".into(), + hunk: Hunk { + context: "one\ntwo\n".into(), + edits: vec![Edit { + range: 4..4, + text: "AND\n".into() + }], + } + }, + DiffEvent::FileEnd { renamed_to: None } + ], + ) + } + + #[gpui::test] + async fn test_apply_diff_successful(cx: &mut TestAppContext) { + let fs = init_test(cx); + + let buffer_1_text = indoc! {r#" + one + two + three + four + five + "# }; + + let buffer_1_text_final = indoc! {r#" + 3 + 4 + 5 + "# }; + + let buffer_2_text = indoc! {r#" + six + seven + eight + nine + ten + "# }; + + let buffer_2_text_final = indoc! {r#" + 5 + six + seven + 7.5 + eight + nine + ten + 11 + "# }; + + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + "file2": buffer_2_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + + let diff = indoc! 
{r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + five + --- a/root/file1 + +++ b/root/file1 + 3 + -four + -five + +4 + +5 + --- a/root/file1 + +++ b/root/file1 + -one + -two + 3 + 4 + --- a/root/file2 + +++ b/root/file2 + +5 + six + --- a/root/file2 + +++ b/root/file2 + seven + +7.5 + eight + --- a/root/file2 + +++ b/root/file2 + ten + +11 + "#}; + + let _buffers = apply_diff(diff, &project, &mut cx.to_async()) + .await + .unwrap(); + let buffer_1 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_1.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), buffer_1_text_final); + }); + let buffer_2 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_2.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), buffer_2_text_final); + }); + } + + #[gpui::test] + async fn test_apply_diff_non_unique(cx: &mut TestAppContext) { + let fs = init_test(cx); + + let buffer_1_text = indoc! {r#" + one + two + three + four + five + one + two + three + four + five + "# }; + + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/file1"), cx) + }) + .await + .unwrap(); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + five + "#}; + + let final_text = indoc! 
{r#" + one + two + three + four + five + one + two + 3 + four + five + "#}; + + apply_diff(diff, &project, &mut cx.to_async()) + .await + .expect_err("Non-unique edits should fail"); + + let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0)) + ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())]; + + let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges))) + .await + .unwrap(); + + assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id()); + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + assert_eq!(buffer.text(), final_text); + }); + } + + #[gpui::test] + async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) { + let fs = init_test(cx); + + let buffer_1_text = indoc! {r#" + one two three four + five six seven eight + nine ten eleven twelve + "# }; + + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/file1"), cx) + }) + .await + .unwrap(); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one two three four + -five six seven eight + +five SIX seven eight! + nine ten eleven twelve + "#}; + + let (buffer, edits) = parse_diff(diff, |_path| { + Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_])) + }) + .await + .unwrap(); + + let edits = edits + .into_iter() + .map(|(range, text)| (range.to_point(&buffer), text)) + .collect::>(); + assert_eq!( + edits, + &[ + (Point::new(1, 5)..Point::new(1, 8), "SIX".into()), + (Point::new(1, 20)..Point::new(1, 20), "!".into()) + ] + ); + } + + #[gpui::test] + async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) { + let fs = init_test(cx); + + let start = indoc! 
{r#" + one + two + three + four + five + + four + five + "# }; + + let end = indoc! {r#" + one + two + 3 + four + 5 + + four + five + "# }; + + fs.insert_tree( + path!("/root"), + json!({ + "file1": start, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + -five + +5 + "#}; + + let _buffers = apply_diff(diff, &project, &mut cx.to_async()) + .await + .unwrap(); + + let buffer_1 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_1.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), end); + }); + } + + fn init_test(cx: &mut TestAppContext) -> Arc { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + + FakeFs::new(cx.background_executor.clone()) + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 92e64f7f332accddbca46ee631f64e5b14be376d..d7a645794fd8a21f8b70b3786b435a23a9babe63 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -6,7 +6,8 @@ use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; -use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt}; +use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; +use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput; use collections::HashMap; use edit_prediction_context::{ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, @@ -24,11 +25,13 @@ use gpui::{ use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, 
RefreshLlmTokenListener}; +use open_ai::FunctionDefinition; use project::Project; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; -use std::fmt::Write; +use uuid::Uuid; + use std::ops::Range; use std::path::Path; use std::str::FromStr as _; @@ -42,12 +45,12 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti pub mod merge_excerpts; mod prediction; mod provider; -pub mod related_excerpts; +pub mod retrieval_search; +pub mod udiff; use crate::merge_excerpts::merge_excerpts; use crate::prediction::EditPrediction; -use crate::related_excerpts::find_related_excerpts; -pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery}; +pub use crate::prediction::EditPredictionId; pub use provider::ZetaEditPredictionProvider; /// Maximum number of events to track. @@ -59,9 +62,10 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction target_before_cursor_over_total_bytes: 0.5, }; -pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS); +pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = + ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS); -pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions { +pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { excerpt: DEFAULT_EXCERPT_OPTIONS, }; @@ -122,14 +126,19 @@ pub struct ZetaOptions { #[derive(Debug, Clone, PartialEq)] pub enum ContextMode { - Llm(LlmContextOptions), + Agentic(AgenticContextOptions), Syntax(EditPredictionContextOptions), } +#[derive(Debug, Clone, PartialEq)] +pub struct AgenticContextOptions { + pub excerpt: EditPredictionExcerptOptions, +} + impl ContextMode { pub fn excerpt(&self) -> &EditPredictionExcerptOptions { match self { - ContextMode::Llm(options) => &options.excerpt, + ContextMode::Agentic(options) => &options.excerpt, ContextMode::Syntax(options) => &options.excerpt, } } 
@@ -140,9 +149,8 @@ pub enum ZetaDebugInfo { ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), SearchQueriesGenerated(ZetaSearchQueryDebugInfo), SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), - SearchResultsFiltered(ZetaContextRetrievalDebugInfo), ContextRetrievalFinished(ZetaContextRetrievalDebugInfo), - EditPredicted(ZetaEditPredictionDebugInfo), + EditPredictionRequested(ZetaEditPredictionDebugInfo), } #[derive(Debug)] @@ -165,14 +173,14 @@ pub struct ZetaEditPredictionDebugInfo { pub buffer: WeakEntity, pub position: language::Anchor, pub local_prompt: Result, - pub response_rx: oneshot::Receiver>, + pub response_rx: oneshot::Receiver<(Result, TimeDelta)>, } #[derive(Debug)] pub struct ZetaSearchQueryDebugInfo { pub project: Entity, pub timestamp: Instant, - pub queries: Vec, + pub regex_by_glob: HashMap, } pub type RequestDebugInfo = predict_edits_v3::DebugInfo; @@ -224,7 +232,7 @@ impl CurrentEditPrediction { { let (old_range, old_text) = &old_edits[0]; let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text) + new_range == old_range && new_text.starts_with(old_text.as_ref()) } else { true } @@ -539,7 +547,7 @@ impl Zeta { prediction, } = project_state.current_prediction.as_ref()?; - if prediction.targets_buffer(buffer.read(cx), cx) { + if prediction.targets_buffer(buffer.read(cx)) { Some(BufferEditPrediction::Local { prediction }) } else if *requested_by_buffer_id == buffer.entity_id() { Some(BufferEditPrediction::Jump { prediction }) @@ -639,7 +647,7 @@ impl Zeta { pub fn request_prediction( &mut self, project: &Entity, - buffer: &Entity, + active_buffer: &Entity, position: language::Anchor, cx: &mut Context, ) -> Task>> { @@ -651,8 +659,8 @@ impl Zeta { .read_with(cx, |index, _cx| index.state().clone()) }); let options = self.options.clone(); - let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot + let active_snapshot = active_buffer.read(cx).snapshot(); + let 
Some(excerpt_path) = active_snapshot .file() .map(|path| -> Arc { path.full_path(cx).into() }) else { @@ -678,12 +686,13 @@ impl Zeta { }) .unwrap_or_default(); - let diagnostics = snapshot.diagnostic_sets().clone(); + let diagnostics = active_snapshot.diagnostic_sets().clone(); - let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); + let parent_abs_path = + project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); // TODO data collection let can_collect_data = cx.is_staff(); @@ -692,9 +701,10 @@ impl Zeta { .and_then(|project_state| project_state.context.as_ref()) .unwrap_or(&HashMap::default()) .iter() - .filter_map(|(buffer, ranges)| { - let buffer = buffer.read(cx); + .filter_map(|(buffer_entity, ranges)| { + let buffer = buffer_entity.read(cx); Some(( + buffer_entity.clone(), buffer.snapshot(), buffer.file()?.full_path(cx).into(), ranges.clone(), @@ -703,8 +713,7 @@ impl Zeta { .collect::>(); let request_task = cx.background_spawn({ - let snapshot = snapshot.clone(); - let buffer = buffer.clone(); + let active_buffer = active_buffer.clone(); async move { let index_state = if let Some(index_state) = index_state { Some(index_state.lock_owned().await) @@ -712,8 +721,8 @@ impl Zeta { None }; - let cursor_offset = position.to_offset(&snapshot); - let cursor_point = cursor_offset.to_point(&snapshot); + let cursor_offset = position.to_offset(&active_snapshot); + let cursor_point = cursor_offset.to_point(&active_snapshot); let before_retrieval = chrono::Utc::now(); @@ -721,29 +730,30 @@ impl Zeta { Self::gather_nearby_diagnostics( cursor_offset, &diagnostics, - &snapshot, + &active_snapshot, options.max_diagnostic_bytes, ); - let request = match options.context { - ContextMode::Llm(context_options) => { + let 
cloud_request = match options.context { + ContextMode::Agentic(context_options) => { let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( cursor_point, - &snapshot, + &active_snapshot, &context_options.excerpt, index_state.as_deref(), ) else { return Ok((None, None)); }; - let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start) - ..snapshot.anchor_before(excerpt.range.end); + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); - if let Some(buffer_ix) = included_files - .iter() - .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id()) + if let Some(buffer_ix) = + included_files.iter().position(|(_, snapshot, _, _)| { + snapshot.remote_id() == active_snapshot.remote_id() + }) { - let (buffer, _, ranges) = &mut included_files[buffer_ix]; + let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; let range_ix = ranges .binary_search_by(|probe| { probe @@ -758,15 +768,16 @@ impl Zeta { included_files.swap(buffer_ix, last_ix); } else { included_files.push(( - snapshot, + active_buffer.clone(), + active_snapshot, excerpt_path.clone(), vec![excerpt_anchor_range], )); } let included_files = included_files - .into_iter() - .map(|(buffer, path, ranges)| { + .iter() + .map(|(_, buffer, path, ranges)| { let excerpts = merge_excerpts( &buffer, ranges.iter().map(|range| { @@ -775,7 +786,7 @@ impl Zeta { }), ); predict_edits_v3::IncludedFile { - path, + path: path.clone(), max_row: Line(buffer.max_point().row), excerpts, } @@ -809,7 +820,7 @@ impl Zeta { ContextMode::Syntax(context_options) => { let Some(context) = EditPredictionContext::gather_context( cursor_point, - &snapshot, + &active_snapshot, parent_abs_path.as_deref(), &context_options, index_state.as_deref(), @@ -834,24 +845,27 @@ impl Zeta { } }; + let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); + let retrieval_time = chrono::Utc::now() - before_retrieval; let debug_response_tx 
= if let Some(debug_tx) = &debug_tx { let (response_tx, response_rx) = oneshot::channel(); - let local_prompt = build_prompt(&request) - .map(|(prompt, _)| prompt) - .map_err(|err| err.to_string()); - debug_tx - .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo { - request: request.clone(), - retrieval_time, - buffer: buffer.downgrade(), - local_prompt, - position, - response_rx, - })) + .unbounded_send(ZetaDebugInfo::EditPredictionRequested( + ZetaEditPredictionDebugInfo { + request: cloud_request.clone(), + retrieval_time, + buffer: active_buffer.downgrade(), + local_prompt: match prompt_result.as_ref() { + Ok((prompt, _)) => Ok(prompt.clone()), + Err(err) => Err(err.to_string()), + }, + position, + response_rx, + }, + )) .ok(); Some(response_tx) } else { @@ -861,61 +875,144 @@ impl Zeta { if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { if let Some(debug_response_tx) = debug_response_tx { debug_response_tx - .send(Err("Request skipped".to_string())) + .send((Err("Request skipped".to_string()), TimeDelta::zero())) .ok(); } anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") } + let (prompt, _) = prompt_result?; + let request = open_ai::Request { + model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: Default::default(), + temperature: 0.7, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + log::trace!("Sending edit prediction request"); + + let before_request = chrono::Utc::now(); let response = - Self::send_prediction_request(client, llm_token, app_version, request).await; + Self::send_raw_llm_request(client, llm_token, app_version, request).await; + let request_time = chrono::Utc::now() - before_request; + + log::trace!("Got edit 
prediction response"); if let Some(debug_response_tx) = debug_response_tx { debug_response_tx - .send( + .send(( response .as_ref() .map_err(|err| err.to_string()) .map(|response| response.0.clone()), - ) + request_time, + )) .ok(); } - response.map(|(res, usage)| (Some(res), usage)) + let (mut res, usage) = response?; + + let request_id = EditPredictionId(Uuid::from_str(&res.id)?); + + let Some(choice) = res.choices.pop() else { + return Ok((None, usage)); + }; + + let output_text = match choice.message { + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. + } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return Ok((None, usage)); + } + + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. } => { + log::error!("Expected text, got an image"); + return Ok((None, usage)); + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return Ok((None, usage)); + } + }; + + let (edited_buffer_snapshot, edits) = + crate::udiff::parse_diff(&output_text, |path| { + included_files + .iter() + .find_map(|(_, buffer, probe_path, ranges)| { + if probe_path.as_ref() == path { + Some((buffer, ranges.as_slice())) + } else { + None + } + }) + }) + .await?; + + let edited_buffer = included_files + .iter() + .find_map(|(buffer, snapshot, _, _)| { + if snapshot.remote_id() == edited_buffer_snapshot.remote_id() { + Some(buffer.clone()) + } else { + None + } + }) + .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?; + + anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage)) } }); - let buffer = buffer.clone(); - cx.spawn({ - let project = project.clone(); async move |this, cx| { - let Some(response) = 
Self::handle_api_response(&this, request_task.await, cx)? + let Some((id, edited_buffer, edited_buffer_snapshot, edits)) = + Self::handle_api_response(&this, request_task.await, cx)? else { return Ok(None); }; // TODO telemetry: duration, etc - Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await) + Ok( + EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx) + .await, + ) } }) } - async fn send_prediction_request( + async fn send_raw_llm_request( client: Arc, llm_token: LlmApiToken, app_version: SemanticVersion, - request: predict_edits_v3::PredictEditsRequest, - ) -> Result<( - predict_edits_v3::PredictEditsResponse, - Option, - )> { + request: open_ai::Request, + ) -> Result<(open_ai::Response, Option)> { let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { http_client::Url::parse(&predict_edits_url)? } else { client .http_client() - .build_zed_llm_url("/predict_edits/v3", &[])? + .build_zed_llm_url("/predict_edits/raw", &[])? }; Self::send_api_request( @@ -1052,7 +1149,7 @@ impl Zeta { cursor_position: language::Anchor, cx: &mut Context, ) { - if !matches!(&self.options().context, ContextMode::Llm { .. }) { + if !matches!(&self.options().context, ContextMode::Agentic { .. 
}) { return; } @@ -1100,36 +1197,149 @@ impl Zeta { cursor_position: language::Anchor, cx: &mut Context, ) -> Task> { + let Some(zeta_project) = self.projects.get(&project.entity_id()) else { + return Task::ready(anyhow::Ok(())); + }; + + let ContextMode::Agentic(options) = &self.options().context else { + return Task::ready(anyhow::Ok(())); + }; + + let snapshot = buffer.read(cx).snapshot(); + let cursor_point = cursor_position.to_point(&snapshot); + let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &snapshot, + &options.excerpt, + None, + ) else { + return Task::ready(Ok(())); + }; + + let current_file_path: Arc = snapshot + .file() + .map(|f| f.full_path(cx).into()) + .unwrap_or_else(|| Path::new("untitled").into()); + + let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt( + predict_edits_v3::PlanContextRetrievalRequest { + excerpt: cursor_excerpt.text(&snapshot).body, + excerpt_path: current_file_path, + excerpt_line_range: cursor_excerpt.line_range, + cursor_file_max_row: Line(snapshot.max_point().row), + events: zeta_project + .events + .iter() + .filter_map(|ev| ev.to_request_event(cx)) + .collect(), + }, + ) { + Ok(prompt) => prompt, + Err(err) => { + return Task::ready(Err(err)); + } + }; + + let app_version = AppVersion::global(cx); + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let debug_tx = self.debug_tx.clone(); + + let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA; + + let request = open_ai::Request { + model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: Default::default(), + temperature: 0.7, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![open_ai::ToolDefinition::Function { + function: FunctionDefinition { + name: 
cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(), + description: Some(tool_description.clone()), + parameters: Some(tool_schema.clone()), + }, + }], + prompt_cache_key: None, + reasoning_effort: None, + }; + cx.spawn(async move |this, cx| { - let related_excerpts_result = this - .update(cx, |this, cx| { - let Some(zeta_project) = this.projects.get(&project.entity_id()) else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + log::trace!("Sending search planning request"); + let response = + Self::send_raw_llm_request(client, llm_token, app_version, request).await; + let mut response = Self::handle_api_response(&this, response, cx)?; + + log::trace!("Got search planning response"); + + let choice = response + .choices + .pop() + .context("No choices in retrieval response")?; + let open_ai::RequestMessage::Assistant { + content: _, + tool_calls, + } = choice.message + else { + anyhow::bail!("Retrieval response didn't include an assistant message"); + }; - let ContextMode::Llm(options) = &this.options().context else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + let mut regex_by_glob: HashMap = HashMap::default(); + for tool_call in tool_calls { + let open_ai::ToolCallContent::Function { function } = tool_call.content; + if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME { + log::warn!( + "Context retrieval response tried to call an unknown tool: {}", + function.name + ); - let mut edit_history_unified_diff = String::new(); + continue; + } - for event in zeta_project.events.iter() { - if let Some(event) = event.to_request_event(cx) { - writeln!(&mut edit_history_unified_diff, "{event}").ok(); - } + let input: SearchToolInput = serde_json::from_str(&function.arguments)?; + for query in input.queries { + let regex = regex_by_glob.entry(query.glob).or_default(); + if !regex.is_empty() { + regex.push('|'); } + regex.push_str(&query.regex); + } + } - find_related_excerpts( - buffer.clone(), - cursor_position, - 
&project, - edit_history_unified_diff, - options, - this.debug_tx.clone(), - cx, - ) - })? - .await; + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( + ZetaSearchQueryDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + regex_by_glob: regex_by_glob.clone(), + }, + )) + .ok(); + } + + log::trace!("Running retrieval search: {regex_by_glob:#?}"); + + let related_excerpts_result = + retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await; + + log::trace!("Search queries executed"); + + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( + ZetaContextRetrievalDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + }, + )) + .ok(); + } this.update(cx, |this, _cx| { let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { @@ -1249,7 +1459,7 @@ impl Zeta { &snapshot, parent_abs_path.as_deref(), match &options.context { - ContextMode::Llm(_) => { + ContextMode::Agentic(_) => { // TODO panic!("Llm mode not supported in zeta cli yet"); } @@ -1426,15 +1636,11 @@ fn add_signature( #[cfg(test)] mod tests { - use std::{ - path::{Path, PathBuf}, - sync::Arc, - }; + use std::{path::Path, sync::Arc}; use client::UserStore; use clock::FakeSystemClock; - use cloud_llm_client::predict_edits_v3::{self, Point}; - use edit_prediction_context::Line; + use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, @@ -1445,7 +1651,8 @@ mod tests { prelude::*, }; use indoc::indoc; - use language::{LanguageServerId, OffsetRangeExt as _}; + use language::OffsetRangeExt as _; + use open_ai::Usage; use pretty_assertions::{assert_eq, assert_matches}; use project::{FakeFs, Project}; use serde_json::json; @@ -1462,8 +1669,8 @@ mod tests { fs.insert_tree( "/root", json!({ - "1.txt": "Hello!\nHow\nBye", - "2.txt": "Hola!\nComo\nAdios" + 
"1.txt": "Hello!\nHow\nBye\n", + "2.txt": "Hola!\nComo\nAdios\n" }), ) .await; @@ -1489,16 +1696,17 @@ mod tests { zeta.refresh_prediction(&project, &buffer1, position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx - .send(predict_edits_v3::PredictEditsResponse { - request_id: Uuid::new_v4(), - edits: vec![predict_edits_v3::Edit { - path: Path::new(path!("root/1.txt")).into(), - range: Line(0)..Line(snapshot1.max_point().row + 1), - content: "Hello!\nHow are you?\nBye".into(), - }], - debug_info: None, - }) + .send(model_response(indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) .unwrap(); prediction_task.await.unwrap(); @@ -1509,21 +1717,67 @@ mod tests { assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); + // Context refresh + let refresh_task = zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx + .send(open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: None, + tool_calls: vec![open_ai::ToolCall { + id: "search".into(), + content: open_ai::ToolCallContent::Function { + function: open_ai::FunctionContent { + name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME + .to_string(), + arguments: serde_json::to_string(&SearchToolInput { + queries: Box::new([SearchToolQuery { + glob: "root/2.txt".to_string(), + regex: ".".to_string(), + }]), + }) + .unwrap(), + }, + }, + }], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }) + .unwrap(); + refresh_task.await.unwrap(); + + zeta.update(cx, |zeta, _cx| { + zeta.discard_current_prediction(&project); + }); + // Prediction for another file let 
prediction_task = zeta.update(cx, |zeta, cx| { zeta.refresh_prediction(&project, &buffer1, position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); respond_tx - .send(predict_edits_v3::PredictEditsResponse { - request_id: Uuid::new_v4(), - edits: vec![predict_edits_v3::Edit { - path: Path::new(path!("root/2.txt")).into(), - range: Line(0)..Line(snapshot1.max_point().row + 1), - content: "Hola!\nComo estas?\nAdios".into(), - }], - debug_info: None, - }) + .send(model_response(indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + Hola! + -Como + +Como estas? + Adios + "#})) .unwrap(); prediction_task.await.unwrap(); zeta.read_with(cx, |zeta, cx| { @@ -1532,7 +1786,7 @@ mod tests { .unwrap(); assert_matches!( prediction, - BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt")) + BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) ); }); @@ -1559,7 +1813,7 @@ mod tests { fs.insert_tree( "/root", json!({ - "foo.md": "Hello!\nHow\nBye" + "foo.md": "Hello!\nHow\nBye\n" }), ) .await; @@ -1579,29 +1833,31 @@ mod tests { zeta.request_prediction(&project, &buffer, position, cx) }); - let (request, respond_tx) = req_rx.next().await.unwrap(); - assert_eq!( - request.excerpt_path.as_ref(), - Path::new(path!("root/foo.md")) - ); - assert_eq!( - request.cursor_point, - Point { - line: Line(1), - column: 3 - } - ); + let (_, respond_tx) = req_rx.next().await.unwrap(); + + // TODO Put back when we have a structured request again + // assert_eq!( + // request.excerpt_path.as_ref(), + // Path::new(path!("root/foo.md")) + // ); + // assert_eq!( + // request.cursor_point, + // Point { + // line: Line(1), + // column: 3 + // } + // ); respond_tx - .send(predict_edits_v3::PredictEditsResponse { - request_id: Uuid::new_v4(), - edits: vec![predict_edits_v3::Edit { - path: Path::new(path!("root/foo.md")).into(), - range: 
Line(0)..Line(snapshot.max_point().row + 1), - content: "Hello!\nHow are you?\nBye".into(), - }], - debug_info: None, - }) + .send(model_response(indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) .unwrap(); let prediction = prediction_task.await.unwrap().unwrap(); @@ -1611,7 +1867,7 @@ mod tests { prediction.edits[0].0.to_point(&snapshot).start, language::Point::new(1, 3) ); - assert_eq!(prediction.edits[0].1, " are you?"); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); } #[gpui::test] @@ -1621,7 +1877,7 @@ mod tests { fs.insert_tree( "/root", json!({ - "foo.md": "Hello!\n\nBye" + "foo.md": "Hello!\n\nBye\n" }), ) .await; @@ -1652,34 +1908,30 @@ mod tests { let (request, respond_tx) = req_rx.next().await.unwrap(); - assert_eq!(request.events.len(), 1); - assert_eq!( - request.events[0], - predict_edits_v3::Event::BufferChange { - path: Some(PathBuf::from(path!("root/foo.md"))), - old_path: None, - diff: indoc! {" - @@ -1,3 +1,3 @@ - Hello! - - - +How - Bye - "} - .to_string(), - predicted: false - } + let prompt = prompt_from_request(&request); + assert!( + prompt.contains(indoc! {" + --- a/root/foo.md + +++ b/root/foo.md + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "}), + "{prompt}" ); respond_tx - .send(predict_edits_v3::PredictEditsResponse { - request_id: Uuid::new_v4(), - edits: vec![predict_edits_v3::Edit { - path: Path::new(path!("root/foo.md")).into(), - range: Line(0)..Line(snapshot.max_point().row + 1), - content: "Hello!\nHow are you?\nBye".into(), - }], - debug_info: None, - }) + .send(model_response(indoc! {r#" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? 
+ Bye + "#})) .unwrap(); let prediction = prediction_task.await.unwrap().unwrap(); @@ -1689,114 +1941,150 @@ mod tests { prediction.edits[0].0.to_point(&snapshot).start, language::Point::new(1, 3) ); - assert_eq!(prediction.edits[0].1, " are you?"); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); } - #[gpui::test] - async fn test_request_diagnostics(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nBye" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + // Skipped until we start including diagnostics in prompt + // #[gpui::test] + // async fn test_request_diagnostics(cx: &mut TestAppContext) { + // let (zeta, mut req_rx) = init_test(cx); + // let fs = FakeFs::new(cx.executor()); + // fs.insert_tree( + // "/root", + // json!({ + // "foo.md": "Hello!\nBye" + // }), + // ) + // .await; + // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); + // let diagnostic = lsp::Diagnostic { + // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + // severity: Some(lsp::DiagnosticSeverity::ERROR), + // message: "\"Hello\" deprecated. 
Use \"Hi\" instead".to_string(), + // ..Default::default() + // }; + + // project.update(cx, |project, cx| { + // project.lsp_store().update(cx, |lsp_store, cx| { + // // Create some diagnostics + // lsp_store + // .update_diagnostics( + // LanguageServerId(0), + // lsp::PublishDiagnosticsParams { + // uri: path_to_buffer_uri.clone(), + // diagnostics: vec![diagnostic], + // version: None, + // }, + // None, + // language::DiagnosticSourceKind::Pushed, + // &[], + // cx, + // ) + // .unwrap(); + // }); + // }); + + // let buffer = project + // .update(cx, |project, cx| { + // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + // project.open_buffer(path, cx) + // }) + // .await + // .unwrap(); + + // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + // let position = snapshot.anchor_before(language::Point::new(0, 0)); + + // let _prediction_task = zeta.update(cx, |zeta, cx| { + // zeta.request_prediction(&project, &buffer, position, cx) + // }); + + // let (request, _respond_tx) = req_rx.next().await.unwrap(); + + // assert_eq!(request.diagnostic_groups.len(), 1); + // let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) + // .unwrap(); + // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 + // assert_eq!( + // value, + // json!({ + // "entries": [{ + // "range": { + // "start": 8, + // "end": 10 + // }, + // "diagnostic": { + // "source": null, + // "code": null, + // "code_description": null, + // "severity": 1, + // "message": "\"Hello\" deprecated. 
Use \"Hi\" instead", + // "markdown": null, + // "group_id": 0, + // "is_primary": true, + // "is_disk_based": false, + // "is_unnecessary": false, + // "source_kind": "Pushed", + // "data": null, + // "underline": true + // } + // }], + // "primary_ix": 0 + // }) + // ); + // } + + fn model_response(text: &str) -> open_ai::Response { + open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(text.to_string())), + tool_calls: vec![], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + } + } - let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); - let diagnostic = lsp::Diagnostic { - range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), - severity: Some(lsp::DiagnosticSeverity::ERROR), - message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), - ..Default::default() + fn prompt_from_request(request: &open_ai::Request) -> &str { + assert_eq!(request.messages.len(), 1); + let open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(content), + .. + } = &request.messages[0] + else { + panic!( + "Request does not have single user message of type Plain. 
{:#?}", + request + ); }; - - project.update(cx, |project, cx| { - project.lsp_store().update(cx, |lsp_store, cx| { - // Create some diagnostics - lsp_store - .update_diagnostics( - LanguageServerId(0), - lsp::PublishDiagnosticsParams { - uri: path_to_buffer_uri.clone(), - diagnostics: vec![diagnostic], - version: None, - }, - None, - language::DiagnosticSourceKind::Pushed, - &[], - cx, - ) - .unwrap(); - }); - }); - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(0, 0)); - - let _prediction_task = zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, cx) - }); - - let (request, _respond_tx) = req_rx.next().await.unwrap(); - - assert_eq!(request.diagnostic_groups.len(), 1); - let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) - .unwrap(); - // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 - assert_eq!( - value, - json!({ - "entries": [{ - "range": { - "start": 8, - "end": 10 - }, - "diagnostic": { - "source": null, - "code": null, - "code_description": null, - "severity": 1, - "message": "\"Hello\" deprecated. 
Use \"Hi\" instead", - "markdown": null, - "group_id": 0, - "is_primary": true, - "is_disk_based": false, - "is_unnecessary": false, - "source_kind": "Pushed", - "data": null, - "underline": true - } - }], - "primary_ix": 0 - }) - ); + content } fn init_test( cx: &mut TestAppContext, ) -> ( Entity, - mpsc::UnboundedReceiver<( - predict_edits_v3::PredictEditsRequest, - oneshot::Sender, - )>, + mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, ) { cx.update(move |cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); language::init(cx); Project::init_settings(cx); + zlog::init_test(); let (req_tx, req_rx) = mpsc::unbounded(); @@ -1811,7 +2099,7 @@ mod tests { "token": "test" })) .unwrap(), - "/predict_edits/v3" => { + "/predict_edits/raw" => { let mut buf = Vec::new(); body.read_to_end(&mut buf).await.ok(); let req = serde_json::from_slice(&buf).unwrap(); diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index 703dbd08b20184c6cd09f68e41cabbc296309483..89d0ce8338624906d2262a7d8314700f6720cff1 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -27,10 +27,11 @@ log.workspace = true multi_buffer.workspace = true ordered-float.workspace = true project.workspace = true +regex-syntax = "0.8.8" serde.workspace = true +serde_json.workspace = true telemetry.workspace = true text.workspace = true -regex-syntax = "0.8.8" ui.workspace = true ui_input.workspace = true util.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 9532d77622645f80696d69ed92b0190e48f838c7..685029cc4a2993227725c17e283c660da5c1d5e1 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -45,7 +45,6 @@ struct RetrievalRun { started_at: Instant, search_results_generated_at: Option, search_results_executed_at: Option, - search_results_filtered_at: Option, finished_at: Option, } @@ 
-117,17 +116,12 @@ impl Zeta2ContextView { self.handle_search_queries_executed(info, window, cx); } } - ZetaDebugInfo::SearchResultsFiltered(info) => { - if info.project == self.project { - self.handle_search_results_filtered(info, window, cx); - } - } ZetaDebugInfo::ContextRetrievalFinished(info) => { if info.project == self.project { self.handle_context_retrieval_finished(info, window, cx); } } - ZetaDebugInfo::EditPredicted(_) => {} + ZetaDebugInfo::EditPredictionRequested(_) => {} } } @@ -159,7 +153,6 @@ impl Zeta2ContextView { started_at: info.timestamp, search_results_generated_at: None, search_results_executed_at: None, - search_results_filtered_at: None, finished_at: None, }); @@ -218,18 +211,18 @@ impl Zeta2ContextView { run.search_results_generated_at = Some(info.timestamp); run.search_queries = info - .queries + .regex_by_glob .into_iter() - .map(|query| { + .map(|(glob, regex)| { let mut regex_parser = regex_syntax::ast::parse::Parser::new(); GlobQueries { - glob: query.glob, - alternations: match regex_parser.parse(&query.regex) { + glob, + alternations: match regex_parser.parse(®ex) { Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => { alt.asts.iter().map(|ast| ast.to_string()).collect() } - _ => vec![query.regex], + _ => vec![regex], }, } }) @@ -256,20 +249,6 @@ impl Zeta2ContextView { cx.notify(); } - fn handle_search_results_filtered( - &mut self, - info: ZetaContextRetrievalDebugInfo, - _window: &mut Window, - cx: &mut Context, - ) { - let Some(run) = self.runs.back_mut() else { - return; - }; - - run.search_results_filtered_at = Some(info.timestamp); - cx.notify(); - } - fn handle_go_back( &mut self, _: &Zeta2ContextGoBack, @@ -398,19 +377,10 @@ impl Zeta2ContextView { }; div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis())); - let Some(t3) = run.search_results_filtered_at else { - return pending_message(div, "Filtering results..."); - }; - div = - div.child(format!("Filtered results: {:>5} ms", (t3 - t2).as_millis())); - - 
let Some(t4) = run.finished_at else { - return pending_message(div, "Building excerpts"); - }; - div = div - .child(format!("Build excerpts: {:>5} µs", (t4 - t3).as_micros())) - .child(format!("Total: {:>5} ms", (t4 - t0).as_millis())); - div + div.child(format!( + "Total: {:>5} ms", + (run.finished_at.unwrap_or(t0) - t0).as_millis() + )) }), ) } diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 89f9dcd5e318c5c21d0121a52b1f39a4f1bd8848..756fff5d621a85f7936a980d71f68c87098c4539 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -5,7 +5,7 @@ use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; use chrono::TimeDelta; use client::{Client, UserStore}; use cloud_llm_client::predict_edits_v3::{ - self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat, + DeclarationScoreComponents, PredictEditsRequest, PromptFormat, }; use collections::HashMap; use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer}; @@ -23,7 +23,7 @@ use ui_input::InputField; use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; use zeta2::{ - ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, Zeta, Zeta2FeatureFlag, + AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions, }; @@ -123,6 +123,7 @@ struct LastPrediction { context_editor: Entity, prompt_editor: Entity, retrieval_time: TimeDelta, + request_time: Option, buffer: WeakEntity, position: language::Anchor, state: LastPredictionState, @@ -143,7 +144,7 @@ enum LastPredictionState { model_response_editor: Entity, feedback_editor: Entity, feedback: Option, - response: predict_edits_v3::PredictEditsResponse, + request_id: String, }, Failed { message: String, @@ -217,7 +218,7 @@ impl Zeta2Inspector { }); match 
&options.context { - ContextMode::Llm(_) => { + ContextMode::Agentic(_) => { self.context_mode = ContextModeState::Llm; } ContextMode::Syntax(_) => { @@ -307,9 +308,11 @@ impl Zeta2Inspector { }; let context = match zeta_options.context { - ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions { - excerpt: excerpt_options, - }), + ContextMode::Agentic(_context_options) => { + ContextMode::Agentic(AgenticContextOptions { + excerpt: excerpt_options, + }) + } ContextMode::Syntax(context_options) => { let max_retrieved_declarations = match &this.context_mode { ContextModeState::Llm => { @@ -368,7 +371,7 @@ impl Zeta2Inspector { let language_registry = self.project.read(cx).languages().clone(); async move |this, cx| { let mut languages = HashMap::default(); - let ZetaDebugInfo::EditPredicted(prediction) = prediction else { + let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else { return; }; for ext in prediction @@ -396,6 +399,8 @@ impl Zeta2Inspector { .await .log_err(); + let json_language = language_registry.language_for_name("Json").await.log_err(); + this.update_in(cx, |this, window, cx| { let context_editor = cx.new(|cx| { let mut excerpt_score_components = HashMap::default(); @@ -492,25 +497,15 @@ impl Zeta2Inspector { let task = cx.spawn_in(window, { let markdown_language = markdown_language.clone(); + let json_language = json_language.clone(); async move |this, cx| { let response = response_rx.await; this.update_in(cx, |this, window, cx| { if let Some(prediction) = this.last_prediction.as_mut() { prediction.state = match response { - Ok(Ok(response)) => { - if let Some(debug_info) = &response.debug_info { - prediction.prompt_editor.update( - cx, - |prompt_editor, cx| { - prompt_editor.set_text( - debug_info.prompt.as_str(), - window, - cx, - ); - }, - ); - } + Ok((Ok(response), request_time)) => { + prediction.request_time = Some(request_time); let feedback_editor = cx.new(|cx| { let buffer = cx.new(|cx| { @@ -577,16 
+572,11 @@ impl Zeta2Inspector { model_response_editor: cx.new(|cx| { let buffer = cx.new(|cx| { let mut buffer = Buffer::local( - response - .debug_info - .as_ref() - .map(|p| p.model_response.as_str()) - .unwrap_or( - "(Debug info not available)", - ), + serde_json::to_string_pretty(&response) + .unwrap_or_default(), cx, ); - buffer.set_language(markdown_language, cx); + buffer.set_language(json_language, cx); buffer }); let buffer = cx.new(|cx| { @@ -607,10 +597,11 @@ impl Zeta2Inspector { }), feedback_editor, feedback: None, - response, + request_id: response.id.clone(), } } - Ok(Err(err)) => { + Ok((Err(err), request_time)) => { + prediction.request_time = Some(request_time); LastPredictionState::Failed { message: err } } Err(oneshot::Canceled) => LastPredictionState::Failed { @@ -644,6 +635,7 @@ impl Zeta2Inspector { editor }), retrieval_time, + request_time: None, buffer, position, state: LastPredictionState::Requested, @@ -700,7 +692,7 @@ impl Zeta2Inspector { feedback: feedback_state, feedback_editor, model_response_editor, - response, + request_id, .. 
} = &mut last_prediction.state else { @@ -734,11 +726,10 @@ impl Zeta2Inspector { telemetry::event!( "Zeta2 Prediction Rated", - id = response.request_id, + id = request_id, kind = kind, text = text, request = last_prediction.request, - response = response, project_snapshot = project_snapshot, ); }) @@ -834,11 +825,11 @@ impl Zeta2Inspector { let current_options = this.zeta.read(cx).options().clone(); match current_options.context.clone() { - ContextMode::Llm(_) => {} + ContextMode::Agentic(_) => {} ContextMode::Syntax(context_options) => { let options = ZetaOptions { - context: ContextMode::Llm( - LlmContextOptions { + context: ContextMode::Agentic( + AgenticContextOptions { excerpt: context_options.excerpt, }, ), @@ -865,7 +856,7 @@ impl Zeta2Inspector { let current_options = this.zeta.read(cx).options().clone(); match current_options.context.clone() { - ContextMode::Llm(context_options) => { + ContextMode::Agentic(context_options) => { let options = ZetaOptions { context: ContextMode::Syntax( EditPredictionContextOptions { @@ -976,25 +967,6 @@ impl Zeta2Inspector { return None; }; - let (prompt_planning_time, inference_time, parsing_time) = - if let LastPredictionState::Success { - response: - PredictEditsResponse { - debug_info: Some(debug_info), - .. - }, - .. 
- } = &prediction.state - { - ( - Some(debug_info.prompt_planning_time), - Some(debug_info.inference_time), - Some(debug_info.parsing_time), - ) - } else { - (None, None, None) - }; - Some( v_flex() .p_4() @@ -1005,12 +977,7 @@ impl Zeta2Inspector { "Context retrieval", Some(prediction.retrieval_time), )) - .child(Self::render_duration( - "Prompt planning", - prompt_planning_time, - )) - .child(Self::render_duration("Inference", inference_time)) - .child(Self::render_duration("Parsing", parsing_time)), + .child(Self::render_duration("Request", prediction.request_time)), ) } diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 6b5e2d0eecb8ca5e38ca233254aa1b0271448a11..5ffdd8ccff6601cf99b2bb3237f46cab224b0daf 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -7,13 +7,14 @@ use std::{ use anyhow::Result; use clap::Args; -use cloud_llm_client::udiff::DiffLine; use collections::HashSet; use gpui::AsyncApp; +use zeta2::udiff::DiffLine; use crate::{ example::{Example, NamedExample}, headless::ZetaCliAppState, + paths::CACHE_DIR, predict::{PredictionDetails, zeta2_predict}, }; @@ -54,10 +55,8 @@ pub async fn run_evaluate_one( app_state: Arc, cx: &mut AsyncApp, ) -> Result { - let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default()) - .join("../../target/zeta-prediction-cache"); let example = NamedExample::load(&example_path).unwrap(); - let example_cache_path = cache_dir.join(&example_path.file_name().unwrap()); + let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap()); let predictions = if !re_run && example_cache_path.exists() { let file_contents = fs::read_to_string(&example_cache_path)?; @@ -74,7 +73,7 @@ pub async fn run_evaluate_one( }; if !example_cache_path.exists() { - fs::create_dir_all(&cache_dir).unwrap(); + fs::create_dir_all(&*CACHE_DIR).unwrap(); fs::write( example_cache_path, serde_json::to_string(&predictions).unwrap(), diff --git 
a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 6537068e84cc46f5ab72a0e1bd9e19445a3ec37e..ab62d690887aa42b2fb3de0c7f05cfc0975de177 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -1,28 +1,31 @@ use std::{ borrow::Cow, cell::RefCell, - env, fmt::{self, Display}, fs, io::Write, mem, - ops::Range, path::{Path, PathBuf}, sync::Arc, }; -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use clap::ValueEnum; -use collections::{HashMap, HashSet}; +use cloud_zeta2_prompt::CURSOR_MARKER; +use collections::HashMap; use futures::{ AsyncWriteExt as _, lock::{Mutex, OwnedMutexGuard}, }; use gpui::{AsyncApp, Entity, http_client::Url}; -use language::Buffer; +use language::{Anchor, Buffer}; use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; +use util::{paths::PathStyle, rel_path::RelPath}; +use zeta2::udiff::OpenedBuffers; + +use crate::paths::{REPOS_DIR, WORKTREES_DIR}; const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; const EDIT_HISTORY_HEADING: &str = "Edit History"; @@ -215,12 +218,10 @@ impl NamedExample { let (repo_owner, repo_name) = self.repo_name()?; let file_name = self.file_name(); - let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees"); - let repos_dir = env::current_dir()?.join("target").join("zeta-repos"); - fs::create_dir_all(&repos_dir)?; - fs::create_dir_all(&worktrees_dir)?; + fs::create_dir_all(&*REPOS_DIR)?; + fs::create_dir_all(&*WORKTREES_DIR)?; - let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref()); + let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); let repo_lock = lock_repo(&repo_dir).await; if !repo_dir.is_dir() { @@ -251,7 +252,7 @@ impl NamedExample { }; // Create the worktree for this example if needed. 
- let worktree_path = worktrees_dir.join(&file_name); + let worktree_path = WORKTREES_DIR.join(&file_name); if worktree_path.is_dir() { run_git(&worktree_path, &["clean", "--force", "-d"]).await?; run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; @@ -309,7 +310,6 @@ impl NamedExample { .collect() } - #[allow(unused)] fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { // git@github.com:owner/repo.git if self.example.repository_url.contains('@') { @@ -344,13 +344,63 @@ impl NamedExample { } } + pub async fn cursor_position( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result<(Entity, Anchor)> { + let worktree = project.read_with(cx, |project, cx| { + project.visible_worktrees(cx).next().unwrap() + })?; + let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc(); + let cursor_buffer = project + .update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: worktree.read(cx).id(), + path: cursor_path, + }, + cx, + ) + })? 
+ .await?; + let cursor_offset_within_excerpt = self + .example + .cursor_position + .find(CURSOR_MARKER) + .ok_or_else(|| anyhow!("missing cursor marker"))?; + let mut cursor_excerpt = self.example.cursor_position.clone(); + cursor_excerpt.replace_range( + cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), + "", + ); + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + + let mut matches = text.match_indices(&cursor_excerpt); + let Some((excerpt_offset, _)) = matches.next() else { + anyhow::bail!( + "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" + ); + }; + assert!(matches.next().is_none()); + + Ok(excerpt_offset) + })??; + + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + let cursor_anchor = + cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; + Ok((cursor_buffer, cursor_anchor)) + } + #[must_use] pub async fn apply_edit_history( &self, project: &Entity, cx: &mut AsyncApp, - ) -> Result>> { - apply_diff(&self.example.edit_history, project, cx).await + ) -> Result> { + zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await } } @@ -446,404 +496,3 @@ pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { .lock_owned() .await } - -#[must_use] -pub async fn apply_diff( - diff: &str, - project: &Entity, - cx: &mut AsyncApp, -) -> Result>> { - use cloud_llm_client::udiff::DiffLine; - use std::fmt::Write; - - #[derive(Debug, Default)] - struct HunkState { - context: String, - edits: Vec, - } - - #[derive(Debug)] - struct Edit { - range: Range, - text: String, - } - - let mut old_path = None; - let mut new_path = None; - let mut hunk = HunkState::default(); - let mut diff_lines = diff.lines().map(DiffLine::parse).peekable(); - let mut open_buffers = HashSet::default(); - - while let Some(diff_line) = diff_lines.next() { - match diff_line { - DiffLine::OldPath { path } => 
old_path = Some(path), - DiffLine::NewPath { path } => { - if old_path.is_none() { - anyhow::bail!( - "Found a new path header (`+++`) before an (`---`) old path header" - ); - } - new_path = Some(path) - } - DiffLine::Context(ctx) => { - writeln!(&mut hunk.context, "{ctx}")?; - } - DiffLine::Deletion(del) => { - let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8(); - if let Some(last_edit) = hunk.edits.last_mut() - && last_edit.range.end == range.start - { - last_edit.range.end = range.end; - } else { - hunk.edits.push(Edit { - range, - text: String::new(), - }); - } - writeln!(&mut hunk.context, "{del}")?; - } - DiffLine::Addition(add) => { - let range = hunk.context.len()..hunk.context.len(); - if let Some(last_edit) = hunk.edits.last_mut() - && last_edit.range.end == range.start - { - writeln!(&mut last_edit.text, "{add}").unwrap(); - } else { - hunk.edits.push(Edit { - range, - text: format!("{add}\n"), - }); - } - } - DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {} - } - - let at_hunk_end = match diff_lines.peek() { - Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true, - _ => false, - }; - - if at_hunk_end { - let hunk = mem::take(&mut hunk); - - let Some(old_path) = old_path.as_deref() else { - anyhow::bail!("Missing old path (`---`) header") - }; - - let Some(new_path) = new_path.as_deref() else { - anyhow::bail!("Missing new path (`+++`) header") - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project - .find_project_path(old_path, cx) - .context("Failed to find old_path in project")?; - - anyhow::Ok(project.open_buffer(project_path, cx)) - })?? 
- .await?; - open_buffers.insert(buffer.clone()); - - if old_path != new_path { - project - .update(cx, |project, cx| { - let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap(); - let new_path = ProjectPath { - worktree_id: project_file.worktree_id(cx), - path: project_file.path.clone(), - }; - project.rename_entry(project_file.entry_id.unwrap(), new_path, cx) - })? - .await?; - } - - // TODO is it worth using project search? - buffer.update(cx, |buffer, cx| { - let context_offset = if hunk.context.is_empty() { - 0 - } else { - let text = buffer.text(); - if let Some(offset) = text.find(&hunk.context) { - if text[offset + 1..].contains(&hunk.context) { - anyhow::bail!("Context is not unique enough:\n{}", hunk.context); - } - offset - } else { - anyhow::bail!( - "Failed to match context:\n{}\n\nBuffer:\n{}", - hunk.context, - text - ); - } - }; - - buffer.edit( - hunk.edits.into_iter().map(|edit| { - ( - context_offset + edit.range.start..context_offset + edit.range.end, - edit.text, - ) - }), - None, - cx, - ); - - anyhow::Ok(()) - })??; - } - } - - anyhow::Ok(open_buffers) -} - -#[cfg(test)] -mod tests { - use super::*; - use ::fs::FakeFs; - use gpui::TestAppContext; - use indoc::indoc; - use pretty_assertions::assert_eq; - use project::Project; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - #[gpui::test] - async fn test_apply_diff_successful(cx: &mut TestAppContext) { - let buffer_1_text = indoc! {r#" - one - two - three - four - five - "# }; - - let buffer_1_text_final = indoc! {r#" - 3 - 4 - 5 - "# }; - - let buffer_2_text = indoc! {r#" - six - seven - eight - nine - ten - "# }; - - let buffer_2_text_final = indoc! 
{r#" - 5 - six - seven - 7.5 - eight - nine - ten - 11 - "# }; - - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - }); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - path!("/root"), - json!({ - "file1": buffer_1_text, - "file2": buffer_2_text, - }), - ) - .await; - - let project = Project::test(fs, [path!("/root").as_ref()], cx).await; - - let diff = indoc! {r#" - --- a/root/file1 - +++ b/root/file1 - one - two - -three - +3 - four - five - --- a/root/file1 - +++ b/root/file1 - 3 - -four - -five - +4 - +5 - --- a/root/file1 - +++ b/root/file1 - -one - -two - 3 - 4 - --- a/root/file2 - +++ b/root/file2 - +5 - six - --- a/root/file2 - +++ b/root/file2 - seven - +7.5 - eight - --- a/root/file2 - +++ b/root/file2 - ten - +11 - "#}; - - let _buffers = apply_diff(diff, &project, &mut cx.to_async()) - .await - .unwrap(); - let buffer_1 = project - .update(cx, |project, cx| { - let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - buffer_1.read_with(cx, |buffer, _cx| { - assert_eq!(buffer.text(), buffer_1_text_final); - }); - let buffer_2 = project - .update(cx, |project, cx| { - let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - buffer_2.read_with(cx, |buffer, _cx| { - assert_eq!(buffer.text(), buffer_2_text_final); - }); - } - - #[gpui::test] - async fn test_apply_diff_non_unique(cx: &mut TestAppContext) { - let buffer_1_text = indoc! 
{r#" - one - two - three - four - five - one - two - three - four - five - "# }; - - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - }); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - path!("/root"), - json!({ - "file1": buffer_1_text, - }), - ) - .await; - - let project = Project::test(fs, [path!("/root").as_ref()], cx).await; - - let diff = indoc! {r#" - --- a/root/file1 - +++ b/root/file1 - one - two - -three - +3 - four - five - "#}; - - apply_diff(diff, &project, &mut cx.to_async()) - .await - .expect_err("Non-unique edits should fail"); - } - - #[gpui::test] - async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) { - let start = indoc! {r#" - one - two - three - four - five - - four - five - "# }; - - let end = indoc! {r#" - one - two - 3 - four - 5 - - four - five - "# }; - - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - }); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - path!("/root"), - json!({ - "file1": start, - }), - ) - .await; - - let project = Project::test(fs, [path!("/root").as_ref()], cx).await; - - let diff = indoc! 
{r#" - --- a/root/file1 - +++ b/root/file1 - one - two - -three - +3 - four - -five - +5 - "#}; - - let _buffers = apply_diff(diff, &project, &mut cx.to_async()) - .await - .unwrap(); - - let buffer_1 = project - .update(cx, |project, cx| { - let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - buffer_1.read_with(cx, |buffer, _cx| { - assert_eq!(buffer.text(), end); - }); - } -} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 43d5b899c8e1c3f3656d5752ffd226dc4b73656d..66b4a6c8bd71ce046b6336ecb671d491128af945 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -1,6 +1,7 @@ mod evaluate; mod example; mod headless; +mod paths; mod predict; mod source_location; mod syntax_retrieval_stats; @@ -10,28 +11,22 @@ use crate::evaluate::{EvaluateArguments, run_evaluate}; use crate::example::{ExampleFormat, NamedExample}; use crate::predict::{PredictArguments, run_zeta2_predict}; use crate::syntax_retrieval_stats::retrieval_stats; -use ::serde::Serialize; use ::util::paths::PathStyle; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand}; -use cloud_llm_client::predict_edits_v3::{self, Excerpt}; -use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; +use cloud_llm_client::predict_edits_v3; use edit_prediction_context::{ - EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions, - EditPredictionScoreOptions, Line, + EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, }; -use futures::StreamExt as _; -use futures::channel::mpsc; use gpui::{Application, AsyncApp, Entity, prelude::*}; -use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point}; -use language_model::LanguageModelRegistry; +use language::{Bias, Buffer, BufferSnapshot, Point}; use project::{Project, Worktree}; use 
reqwest_client::ReqwestClient; use serde_json::json; use std::io::{self}; use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; -use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; +use zeta2::ContextMode; use crate::headless::ZetaCliAppState; use crate::source_location::SourceLocation; @@ -79,12 +74,6 @@ enum Zeta2Command { #[command(subcommand)] command: Zeta2SyntaxCommand, }, - Llm { - #[clap(flatten)] - args: Zeta2Args, - #[command(subcommand)] - command: Zeta2LlmCommand, - }, Predict(PredictArguments), Eval(EvaluateArguments), } @@ -107,14 +96,6 @@ enum Zeta2SyntaxCommand { }, } -#[derive(Subcommand, Debug)] -enum Zeta2LlmCommand { - Context { - #[clap(flatten)] - context_args: ContextArgs, - }, -} - #[derive(Debug, Args)] #[group(requires = "worktree")] struct ContextArgs { @@ -388,197 +369,6 @@ async fn zeta2_syntax_context( Ok(output) } -async fn zeta2_llm_context( - zeta2_args: Zeta2Args, - context_args: ContextArgs, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result { - let LoadedContext { - buffer, - clipped_cursor, - snapshot: cursor_snapshot, - project, - .. - } = load_context(&context_args, app_state, cx).await?; - - let cursor_position = cursor_snapshot.anchor_after(clipped_cursor); - - cx.update(|cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry - .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) - .unwrap() - .authenticate(cx) - }) - })? 
- .await?; - - let edit_history_unified_diff = match context_args.edit_history { - Some(events) => events.read_to_string().await?, - None => String::new(), - }; - - let (debug_tx, mut debug_rx) = mpsc::unbounded(); - - let excerpt_options = EditPredictionExcerptOptions { - max_bytes: zeta2_args.max_excerpt_bytes, - min_bytes: zeta2_args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes, - }; - - let related_excerpts = cx - .update(|cx| { - zeta2::related_excerpts::find_related_excerpts( - buffer, - cursor_position, - &project, - edit_history_unified_diff, - &LlmContextOptions { - excerpt: excerpt_options.clone(), - }, - Some(debug_tx), - cx, - ) - })? - .await?; - - let cursor_excerpt = EditPredictionExcerpt::select_from_buffer( - clipped_cursor, - &cursor_snapshot, - &excerpt_options, - None, - ) - .context("line didn't fit")?; - - #[derive(Serialize)] - struct Output { - excerpts: Vec, - formatted_excerpts: String, - meta: OutputMeta, - } - - #[derive(Default, Serialize)] - struct OutputMeta { - search_prompt: String, - search_queries: Vec, - } - - #[derive(Serialize)] - struct OutputExcerpt { - path: PathBuf, - #[serde(flatten)] - excerpt: Excerpt, - } - - let mut meta = OutputMeta::default(); - - while let Some(debug_info) = debug_rx.next().await { - match debug_info { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { - meta.search_prompt = info.search_prompt; - } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { - meta.search_queries = info.queries - } - _ => {} - } - } - - cx.update(|cx| { - let mut excerpts = Vec::new(); - let mut formatted_excerpts = String::new(); - - let cursor_insertions = [( - predict_edits_v3::Point { - line: Line(clipped_cursor.row), - column: clipped_cursor.column, - }, - CURSOR_MARKER, - )]; - - let mut cursor_excerpt_added = false; - - for (buffer, ranges) in related_excerpts { - let excerpt_snapshot = buffer.read(cx).snapshot(); - - let mut line_ranges = 
ranges - .into_iter() - .map(|range| { - let point_range = range.to_point(&excerpt_snapshot); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect::>(); - - let Some(file) = excerpt_snapshot.file() else { - continue; - }; - let path = file.full_path(cx); - - let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx); - if is_cursor_file { - let insertion_ix = line_ranges - .binary_search_by(|probe| { - probe - .start - .cmp(&cursor_excerpt.line_range.start) - .then(cursor_excerpt.line_range.end.cmp(&probe.end)) - }) - .unwrap_or_else(|ix| ix); - line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone()); - cursor_excerpt_added = true; - } - - let merged_excerpts = - zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges) - .into_iter() - .map(|excerpt| OutputExcerpt { - path: path.clone(), - excerpt, - }); - - let excerpt_start_ix = excerpts.len(); - excerpts.extend(merged_excerpts); - - write_codeblock( - &path, - excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt), - if is_cursor_file { - &cursor_insertions - } else { - &[] - }, - Line(excerpt_snapshot.max_point().row), - true, - &mut formatted_excerpts, - ); - } - - if !cursor_excerpt_added { - write_codeblock( - &cursor_snapshot.file().unwrap().full_path(cx), - &[Excerpt { - start_line: cursor_excerpt.line_range.start, - text: cursor_excerpt.text(&cursor_snapshot).body.into(), - }], - &cursor_insertions, - Line(cursor_snapshot.max_point().row), - true, - &mut formatted_excerpts, - ); - } - - let output = Output { - excerpts, - formatted_excerpts, - meta, - }; - - Ok(serde_json::to_string_pretty(&output)?) 
- }) - .unwrap() -} - async fn zeta1_context( args: ContextArgs, app_state: &Arc, @@ -670,13 +460,6 @@ fn main() { }; println!("{}", result.unwrap()); } - Zeta2Command::Llm { args, command } => match command { - Zeta2LlmCommand::Context { context_args } => { - let result = - zeta2_llm_context(args, context_args, &app_state, cx).await; - println!("{}", result.unwrap()); - } - }, }, Command::ConvertExample { path, diff --git a/crates/zeta_cli/src/paths.rs b/crates/zeta_cli/src/paths.rs new file mode 100644 index 0000000000000000000000000000000000000000..61987607bf2a5bb99eae68db4863f97bb282b29c --- /dev/null +++ b/crates/zeta_cli/src/paths.rs @@ -0,0 +1,8 @@ +use std::{env, path::PathBuf, sync::LazyLock}; + +static TARGET_DIR: LazyLock = LazyLock::new(|| env::current_dir().unwrap().join("target")); +pub static CACHE_DIR: LazyLock = + LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache")); +pub static REPOS_DIR: LazyLock = LazyLock::new(|| TARGET_DIR.join("zeta-repos")); +pub static WORKTREES_DIR: LazyLock = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees")); +pub static LOGS_DIR: LazyLock = LazyLock::new(|| TARGET_DIR.join("zeta-logs")); diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index cdf385ae6db0556180ae9a223ff32efc53ad9a02..c94353d5b1dd02648265b70b4b584705f3769f96 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -1,22 +1,20 @@ use crate::example::{ActualExcerpt, NamedExample}; - use crate::headless::ZetaCliAppState; +use crate::paths::LOGS_DIR; use ::serde::Serialize; -use ::util::paths::PathStyle; use anyhow::{Context as _, Result, anyhow}; use clap::Args; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use futures::StreamExt as _; use gpui::AsyncApp; -use language_model::LanguageModelRegistry; -use project::{Project, ProjectPath}; +use project::Project; use serde::Deserialize; use std::cell::Cell; +use std::fs; use std::io::Write; use std::path::PathBuf; use std::sync::Arc; use 
std::time::{Duration, Instant}; -use util::rel_path::RelPath; #[derive(Debug, Args)] pub struct PredictArguments { @@ -50,21 +48,12 @@ pub async fn zeta2_predict( app_state: &Arc, cx: &mut AsyncApp, ) -> Result { + fs::create_dir_all(&*LOGS_DIR)?; let worktree_path = example.setup_worktree().await?; if !AUTHENTICATED.get() { AUTHENTICATED.set(true); - cx.update(|cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry - .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) - .unwrap() - .authenticate(cx) - }) - })? - .await?; - app_state .client .sign_in_with_optional_connect(true, cx) @@ -83,6 +72,8 @@ pub async fn zeta2_predict( ) })?; + let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; + let worktree = project .update(cx, |project, cx| { project.create_worktree(&worktree_path, true, cx) @@ -94,58 +85,30 @@ pub async fn zeta2_predict( })? .await; - let _edited_buffers = example.apply_edit_history(&project, cx).await?; - - let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc(); - - let cursor_buffer = project - .update(cx, |project, cx| { - project.open_buffer( - ProjectPath { - worktree_id: worktree.read(cx).id(), - path: cursor_path, - }, - cx, - ) - })? 
- .await?; - - let cursor_offset_within_excerpt = example - .example - .cursor_position - .find(CURSOR_MARKER) - .ok_or_else(|| anyhow!("missing cursor marker"))?; - let mut cursor_excerpt = example.example.cursor_position.clone(); - cursor_excerpt.replace_range( - cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), - "", - ); - let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { - let text = buffer.text(); - - let mut matches = text.match_indices(&cursor_excerpt); - let Some((excerpt_offset, _)) = matches.next() else { - anyhow::bail!( - "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" - ); - }; - assert!(matches.next().is_none()); + let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; - Ok(excerpt_offset) - })??; + cx.subscribe(&buffer_store, { + let project = project.clone(); + move |_, event, cx| match event { + project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => { + zeta2::Zeta::try_global(cx) + .unwrap() + .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx)); + } + _ => {} + } + })? 
+ .detach(); - let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; - let cursor_anchor = - cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; + let _edited_buffers = example.apply_edit_history(&project, cx).await?; + let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?; - let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.register_buffer(&cursor_buffer, &project, cx); zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) })?; - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; let mut context_retrieval_started_at = None; let mut context_retrieval_finished_at = None; let mut search_queries_generated_at = None; @@ -159,9 +122,14 @@ pub async fn zeta2_predict( match event { zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { context_retrieval_started_at = Some(info.timestamp); + fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?; } zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { search_queries_generated_at = Some(info.timestamp); + fs::write( + LOGS_DIR.join("search_queries.json"), + serde_json::to_string_pretty(&info.regex_by_glob).unwrap(), + )?; } zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { search_queries_executed_at = Some(info.timestamp); @@ -173,11 +141,21 @@ pub async fn zeta2_predict( zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) })?); } - zeta2::ZetaDebugInfo::EditPredicted(request) => { + zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { prediction_started_at = Some(Instant::now()); - request.response_rx.await?.map_err(|err| anyhow!(err))?; + fs::write( + LOGS_DIR.join("prediction_prompt.md"), + &request.local_prompt.unwrap_or_default(), + )?; + + let response = request.response_rx.await?.0.map_err(|err| 
anyhow!(err))?; prediction_finished_at = Some(Instant::now()); + fs::write( + LOGS_DIR.join("prediction_response.json"), + &serde_json::to_string_pretty(&response).unwrap(), + )?; + for included_file in request.request.included_files { let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; result @@ -201,7 +179,6 @@ pub async fn zeta2_predict( } break; } - _ => {} } }