From ee80ba6693b4541b8d9939c34c16cc0f94547d73 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 27 Oct 2025 19:54:42 -0300 Subject: [PATCH] zeta2: LLM-based context gathering (#41326) Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld Co-authored-by: Max Brunsfeld --- Cargo.lock | 2 + .../cloud_llm_client/src/predict_edits_v3.rs | 17 + .../src/cloud_zeta2_prompt.rs | 278 ++++++--- crates/language/Cargo.toml | 1 - crates/language/src/buffer.rs | 49 +- crates/language/src/outline.rs | 7 + crates/languages/src/rust/outline.scm | 5 +- crates/multi_buffer/src/multi_buffer.rs | 7 + crates/outline_panel/src/outline_panel.rs | 1 + crates/zeta2/Cargo.toml | 3 + crates/zeta2/src/merge_excerpts.rs | 192 ++++++ crates/zeta2/src/provider.rs | 4 + crates/zeta2/src/related_excerpts.rs | 586 ++++++++++++++++++ crates/zeta2/src/zeta2.rs | 436 ++++++++++--- crates/zeta2_tools/src/zeta2_tools.rs | 203 ++++-- crates/zeta_cli/src/main.rs | 9 +- crates/zeta_cli/src/retrieval_stats.rs | 19 +- 17 files changed, 1575 insertions(+), 244 deletions(-) create mode 100644 crates/zeta2/src/merge_excerpts.rs create mode 100644 crates/zeta2/src/related_excerpts.rs diff --git a/Cargo.lock b/Cargo.lock index 8b96a89070a6f3d1a5cd179a41a84e3e913ece7a..9dce268507edf8a0554d9b113de044c564d1827e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21640,6 +21640,7 @@ dependencies = [ "clock", "cloud_llm_client", "cloud_zeta2_prompt", + "collections", "edit_prediction", "edit_prediction_context", "feature_flags", @@ -21653,6 +21654,7 @@ dependencies = [ "pretty_assertions", "project", "release_channel", + "schemars 1.0.4", "serde", "serde_json", "settings", diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index e03541e0f7d66bd54d6fbd918debbdc3d6c8d9e7..7166139d9077394e684a8b53ce3d8300cb5fa2db 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -23,7 +23,11 @@ pub struct PredictEditsRequest { pub cursor_point: Point, /// Within `signatures` pub excerpt_parent: Option, + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub included_files: Vec, + #[serde(skip_serializing_if = "Vec::is_empty", default)] pub signatures: Vec, + #[serde(skip_serializing_if = "Vec::is_empty", default)] pub referenced_declarations: Vec, pub events: Vec, #[serde(default)] @@ -44,6 +48,19 @@ pub struct PredictEditsRequest { pub prompt_format: PromptFormat, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IncludedFile { + pub path: Arc, + pub max_row: Line, + pub excerpts: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Excerpt { + pub start_line: Line, + pub text: Arc, +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] pub enum PromptFormat { MarkedExcerpt, diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 284b245acf2305350e6a6a5e7c38dfaa9b16c5d4..1c8b1caf80db28ef936aa9a747b4a163e183134f 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,11 +1,14 @@ //! Zeta2 prompt planning and generation code shared with cloud. use anyhow::{Context as _, Result, anyhow}; -use cloud_llm_client::predict_edits_v3::{self, Line, Point, PromptFormat, ReferencedDeclaration}; +use cloud_llm_client::predict_edits_v3::{ + self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration, +}; use indoc::indoc; use ordered_float::OrderedFloat; use rustc_hash::{FxHashMap, FxHashSet}; use serde::Serialize; +use std::cmp; use std::fmt::Write; use std::sync::Arc; use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path}; @@ -96,7 +99,177 @@ const UNIFIED_DIFF_REMINDER: &str = indoc! {" If you're editing multiple files, be sure to reflect filename in the hunk's header. "}; -pub struct PlannedPrompt<'a> { +pub fn build_prompt( + request: &predict_edits_v3::PredictEditsRequest, +) -> Result<(String, SectionLabels)> { + let mut insertions = match request.prompt_format { + PromptFormat::MarkedExcerpt => vec![ + ( + Point { + line: request.excerpt_line_range.start, + column: 0, + }, + EDITABLE_REGION_START_MARKER_WITH_NEWLINE, + ), + (request.cursor_point, CURSOR_MARKER), + ( + Point { + line: request.excerpt_line_range.end, + column: 0, + }, + EDITABLE_REGION_END_MARKER_WITH_NEWLINE, + ), + ], + PromptFormat::LabeledSections => vec![(request.cursor_point, CURSOR_MARKER)], + PromptFormat::NumLinesUniDiff => { + vec![(request.cursor_point, CURSOR_MARKER)] + } + PromptFormat::OnlySnippets => vec![], + }; + + let mut prompt = match request.prompt_format { + PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), + PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), + PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), + // only intended for use via zeta_cli + PromptFormat::OnlySnippets => String::new(), + }; + + if request.events.is_empty() { + prompt.push_str("(No edit history)\n\n"); + } else { + prompt.push_str( + "The following are the latest edits made by the user, from earlier to later.\n\n", + ); + push_events(&mut prompt, &request.events); + } + + if request.prompt_format == PromptFormat::NumLinesUniDiff { + if request.referenced_declarations.is_empty() { + prompt.push_str(indoc! {" + # File under the cursor: + + The cursor marker <|user_cursor|> indicates the current user cursor position. + The file is in current state, edits from edit history have been applied. + We prepend line numbers (e.g., `123|`); they are not part of the file. + + "}); + } else { + // Note: This hasn't been trained on yet + prompt.push_str(indoc! {" + # Code Excerpts: + + The cursor marker <|user_cursor|> indicates the current user cursor position. + Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor. + Context excerpts are not guaranteed to be relevant, so use your own judgement. + Files are in their current state, edits from edit history have been applied. + We prepend line numbers (e.g., `123|`); they are not part of the file. + + "}); + } + } else { + prompt.push_str("\n## Code\n\n"); + } + + let mut section_labels = Default::default(); + + if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() { + let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?; + section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?; + } else { + if request.prompt_format == PromptFormat::LabeledSections { + anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm"); + } + + for related_file in &request.included_files { + writeln!(&mut prompt, "`````filename={}", related_file.path.display()).unwrap(); + write_excerpts( + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + request.prompt_format == PromptFormat::NumLinesUniDiff, + &mut prompt, + ); + write!(&mut prompt, "`````\n\n").unwrap(); + } + } + + if request.prompt_format == PromptFormat::NumLinesUniDiff { + prompt.push_str(UNIFIED_DIFF_REMINDER); + } + + Ok((prompt, section_labels)) +} + +pub fn write_excerpts<'a>( + excerpts: impl IntoIterator, + sorted_insertions: &[(Point, &str)], + file_line_count: Line, + include_line_numbers: bool, + output: &mut String, +) { + let mut current_row = Line(0); + let mut sorted_insertions = sorted_insertions.iter().peekable(); + + for excerpt in excerpts { + if excerpt.start_line > current_row { + writeln!(output, "…").unwrap(); + } + if excerpt.text.is_empty() { + return; + } + + current_row = excerpt.start_line; + + for mut line in excerpt.text.lines() { + if include_line_numbers { + write!(output, "{}|", current_row.0 + 1).unwrap(); + } + + while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() { + match current_row.cmp(&insertion_location.line) { + cmp::Ordering::Equal => { + let (prefix, suffix) = line.split_at(insertion_location.column as usize); + output.push_str(prefix); + output.push_str(insertion_marker); + line = suffix; + sorted_insertions.next(); + } + cmp::Ordering::Less => break, + cmp::Ordering::Greater => { + sorted_insertions.next(); + break; + } + } + } + output.push_str(line); + output.push('\n'); + current_row.0 += 1; + } + } + + if current_row < file_line_count { + writeln!(output, "…").unwrap(); + } +} + +fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { + if events.is_empty() { + return; + }; + + writeln!(output, "`````diff").unwrap(); + for event in events { + writeln!(output, "{}", event).unwrap(); + } + writeln!(output, "`````\n").unwrap(); +} + +pub struct SyntaxBasedPrompt<'a> { request: &'a predict_edits_v3::PredictEditsRequest, /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in /// `to_prompt_string`. @@ -120,13 +293,13 @@ pub enum DeclarationStyle { Declaration, } -#[derive(Clone, Debug, Serialize)] +#[derive(Default, Clone, Debug, Serialize)] pub struct SectionLabels { pub excerpt_index: usize, pub section_ranges: Vec<(Arc, Range)>, } -impl<'a> PlannedPrompt<'a> { +impl<'a> SyntaxBasedPrompt<'a> { /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following: /// /// Initializes a priority queue by populating it with each snippet, finding the @@ -149,7 +322,7 @@ impl<'a> PlannedPrompt<'a> { /// /// * Does not include file paths / other text when considering max_bytes. pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result { - let mut this = PlannedPrompt { + let mut this = Self { request, snippets: Vec::new(), budget_used: request.excerpt.len(), @@ -354,7 +527,11 @@ impl<'a> PlannedPrompt<'a> { /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive /// chunks. - pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> { + pub fn write( + &'a self, + excerpt_file_insertions: &mut Vec<(Point, &'static str)>, + prompt: &mut String, + ) -> Result { let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> = FxHashMap::default(); for snippet in &self.snippets { @@ -383,95 +560,10 @@ impl<'a> PlannedPrompt<'a> { excerpt_file_snippets.push(&excerpt_snippet); file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true)); - let mut excerpt_file_insertions = match self.request.prompt_format { - PromptFormat::MarkedExcerpt => vec![ - ( - Point { - line: self.request.excerpt_line_range.start, - column: 0, - }, - EDITABLE_REGION_START_MARKER_WITH_NEWLINE, - ), - (self.request.cursor_point, CURSOR_MARKER), - ( - Point { - line: self.request.excerpt_line_range.end, - column: 0, - }, - EDITABLE_REGION_END_MARKER_WITH_NEWLINE, - ), - ], - PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)], - PromptFormat::NumLinesUniDiff => { - vec![(self.request.cursor_point, CURSOR_MARKER)] - } - PromptFormat::OnlySnippets => vec![], - }; - - let mut prompt = match self.request.prompt_format { - PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), - PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), - PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), - // only intended for use via zeta_cli - PromptFormat::OnlySnippets => String::new(), - }; - - if self.request.events.is_empty() { - prompt.push_str("(No edit history)\n\n"); - } else { - prompt.push_str( - "The following are the latest edits made by the user, from earlier to later.\n\n", - ); - Self::push_events(&mut prompt, &self.request.events); - } - - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - if self.request.referenced_declarations.is_empty() { - prompt.push_str(indoc! {" - # File under the cursor: - - The cursor marker <|user_cursor|> indicates the current user cursor position. - The file is in current state, edits from edit history have been applied. - We prepend line numbers (e.g., `123|`); they are not part of the file. - - "}); - } else { - // Note: This hasn't been trained on yet - prompt.push_str(indoc! {" - # Code Excerpts: - - The cursor marker <|user_cursor|> indicates the current user cursor position. - Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor. - Context excerpts are not guaranteed to be relevant, so use your own judgement. - Files are in their current state, edits from edit history have been applied. - We prepend line numbers (e.g., `123|`); they are not part of the file. - - "}); - } - } else { - prompt.push_str("\n## Code\n\n"); - } - let section_labels = - self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?; - - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - prompt.push_str(UNIFIED_DIFF_REMINDER); - } - - Ok((prompt, section_labels)) - } + self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?; - fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { - if events.is_empty() { - return; - }; - - writeln!(output, "`````diff").unwrap(); - for event in events { - writeln!(output, "{}", event).unwrap(); - } - writeln!(output, "`````\n").unwrap(); + Ok(section_labels) } fn push_file_snippets( diff --git a/crates/language/Cargo.toml b/crates/language/Cargo.toml index bbbf9e31a5b39069e93a5f52f18df16bbc9f9671..ffc5ad85d14c293eeeaff9172b21ef58cf9a1cf0 100644 --- a/crates/language/Cargo.toml +++ b/crates/language/Cargo.toml @@ -20,7 +20,6 @@ test-support = [ "text/test-support", "tree-sitter-rust", "tree-sitter-python", - "tree-sitter-rust", "tree-sitter-typescript", "settings/test-support", "util/test-support", diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 41c0e3eec8e8f4daaf5dff706dceea4159fedae1..c2da93aa7399267f6300625da58aba9bf6dccc4f 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -3833,6 +3833,32 @@ impl BufferSnapshot { include_extra_context: bool, theme: Option<&SyntaxTheme>, ) -> Vec> { + self.outline_items_containing_internal( + range, + include_extra_context, + theme, + |this, range| this.anchor_after(range.start)..this.anchor_before(range.end), + ) + } + + pub fn outline_items_as_points_containing( + &self, + range: Range, + include_extra_context: bool, + theme: Option<&SyntaxTheme>, + ) -> Vec> { + self.outline_items_containing_internal(range, include_extra_context, theme, |_, range| { + range + }) + } + + fn outline_items_containing_internal( + &self, + range: Range, + include_extra_context: bool, + theme: Option<&SyntaxTheme>, + range_callback: fn(&Self, Range) -> Range, + ) -> Vec> { let range = range.to_offset(self); let mut matches = self.syntax.matches(range.clone(), &self.text, |grammar| { grammar.outline_config.as_ref().map(|c| &c.query) @@ -3905,19 +3931,16 @@ impl BufferSnapshot { anchor_items.push(OutlineItem { depth: item_ends_stack.len(), - range: self.anchor_after(item.range.start)..self.anchor_before(item.range.end), + range: range_callback(self, item.range.clone()), + source_range_for_text: range_callback(self, item.source_range_for_text.clone()), text: item.text, highlight_ranges: item.highlight_ranges, name_ranges: item.name_ranges, - body_range: item - .body_range - .map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)), + body_range: item.body_range.map(|r| range_callback(self, r)), annotation_range: annotation_row_range.map(|annotation_range| { - self.anchor_after(Point::new(annotation_range.start, 0)) - ..self.anchor_before(Point::new( - annotation_range.end, - self.line_len(annotation_range.end), - )) + let point_range = Point::new(annotation_range.start, 0) + ..Point::new(annotation_range.end, self.line_len(annotation_range.end)); + range_callback(self, point_range) }), }); item_ends_stack.push(item.range.end); @@ -3984,14 +4007,13 @@ impl BufferSnapshot { if buffer_ranges.is_empty() { return None; } + let source_range_for_text = + buffer_ranges.first().unwrap().0.start..buffer_ranges.last().unwrap().0.end; let mut text = String::new(); let mut highlight_ranges = Vec::new(); let mut name_ranges = Vec::new(); - let mut chunks = self.chunks( - buffer_ranges.first().unwrap().0.start..buffer_ranges.last().unwrap().0.end, - true, - ); + let mut chunks = self.chunks(source_range_for_text.clone(), true); let mut last_buffer_range_end = 0; for (buffer_range, is_name) in buffer_ranges { let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end; @@ -4037,6 +4059,7 @@ impl BufferSnapshot { Some(OutlineItem { depth: 0, // We'll calculate the depth later range: item_point_range, + source_range_for_text: source_range_for_text.to_point(self), text, highlight_ranges, name_ranges, diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs index d96cd90e03142c6498ae17bc63e1787d99e8557a..2ce2b42734465a4710a7439f5e2225debc96b04a 100644 --- a/crates/language/src/outline.rs +++ b/crates/language/src/outline.rs @@ -16,6 +16,7 @@ pub struct Outline { pub struct OutlineItem { pub depth: usize, pub range: Range, + pub source_range_for_text: Range, pub text: String, pub highlight_ranges: Vec<(Range, HighlightStyle)>, pub name_ranges: Vec>, @@ -32,6 +33,8 @@ impl OutlineItem { OutlineItem { depth: self.depth, range: self.range.start.to_point(buffer)..self.range.end.to_point(buffer), + source_range_for_text: self.source_range_for_text.start.to_point(buffer) + ..self.source_range_for_text.end.to_point(buffer), text: self.text.clone(), highlight_ranges: self.highlight_ranges.clone(), name_ranges: self.name_ranges.clone(), @@ -205,6 +208,7 @@ mod tests { OutlineItem { depth: 0, range: Point::new(0, 0)..Point::new(5, 0), + source_range_for_text: Point::new(0, 0)..Point::new(0, 9), text: "class Foo".to_string(), highlight_ranges: vec![], name_ranges: vec![6..9], @@ -214,6 +218,7 @@ mod tests { OutlineItem { depth: 0, range: Point::new(2, 0)..Point::new(2, 7), + source_range_for_text: Point::new(0, 0)..Point::new(0, 7), text: "private".to_string(), highlight_ranges: vec![], name_ranges: vec![], @@ -238,6 +243,7 @@ mod tests { OutlineItem { depth: 0, range: Point::new(0, 0)..Point::new(5, 0), + source_range_for_text: Point::new(0, 0)..Point::new(0, 10), text: "fn process".to_string(), highlight_ranges: vec![], name_ranges: vec![3..10], @@ -247,6 +253,7 @@ mod tests { OutlineItem { depth: 0, range: Point::new(7, 0)..Point::new(12, 0), + source_range_for_text: Point::new(0, 0)..Point::new(0, 20), text: "struct DataProcessor".to_string(), highlight_ranges: vec![], name_ranges: vec![7..20], diff --git a/crates/languages/src/rust/outline.scm b/crates/languages/src/rust/outline.scm index 3012995e2a7f23f66b0c1a891789f8fbc3524e6c..a99f53dd2b3154aa3717f67fd683da4a8b57d31b 100644 --- a/crates/languages/src/rust/outline.scm +++ b/crates/languages/src/rust/outline.scm @@ -20,7 +20,7 @@ trait: (_)? @name "for"? @context type: (_) @name - body: (_ "{" @open (_)* "}" @close)) @item + body: (_ . "{" @open "}" @close .)) @item (trait_item (visibility_modifier)? @context @@ -31,7 +31,8 @@ (visibility_modifier)? @context (function_modifiers)? @context "fn" @context - name: (_) @name) @item + name: (_) @name + body: (_ . "{" @open "}" @close .)) @item (function_signature_item (visibility_modifier)? @context diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index 4cd112d231fb340b67a712f235cccddd067234b3..e3ea3b9c92014acff7dab6931b1f756224cee288 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -5451,6 +5451,8 @@ impl MultiBufferSnapshot { Some(OutlineItem { depth: item.depth, range: self.anchor_range_in_excerpt(*excerpt_id, item.range)?, + source_range_for_text: self + .anchor_range_in_excerpt(*excerpt_id, item.source_range_for_text)?, text: item.text, highlight_ranges: item.highlight_ranges, name_ranges: item.name_ranges, @@ -5484,6 +5486,11 @@ impl MultiBufferSnapshot { .flat_map(|item| { Some(OutlineItem { depth: item.depth, + source_range_for_text: Anchor::range_in_buffer( + excerpt_id, + buffer_id, + item.source_range_for_text, + ), range: Anchor::range_in_buffer(excerpt_id, buffer_id, item.range), text: item.text, highlight_ranges: item.highlight_ranges, diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index ebc5946acf97b763d7ec06d264aeaa7169d7c68b..112aa3d21ebda9ef57d3bedda20e3f90735a0173 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -2484,6 +2484,7 @@ impl OutlinePanel { annotation_range: None, range: search_data.context_range.clone(), text: search_data.context_text.clone(), + source_range_for_text: search_data.context_range.clone(), highlight_ranges: search_data .highlights_data .get() diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 7ca140fa353b6404e451fdb79cccfed982b64e27..13bb4e9106de9f5f201ba59106304a6aab4208d1 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -18,6 +18,7 @@ chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true cloud_zeta2_prompt.workspace = true +collections.workspace = true edit_prediction.workspace = true edit_prediction_context.workspace = true feature_flags.workspace = true @@ -29,6 +30,7 @@ language_model.workspace = true log.workspace = true project.workspace = true release_channel.workspace = true +schemars.workspace = true serde.workspace = true serde_json.workspace = true thiserror.workspace = true @@ -43,6 +45,7 @@ cloud_llm_client = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } lsp.workspace = true indoc.workspace = true +language = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2/src/merge_excerpts.rs b/crates/zeta2/src/merge_excerpts.rs new file mode 100644 index 0000000000000000000000000000000000000000..4cb7ab6cf4d3b63e641087f0c22cf0f900f56adc --- /dev/null +++ b/crates/zeta2/src/merge_excerpts.rs @@ -0,0 +1,192 @@ +use cloud_llm_client::predict_edits_v3::{self, Excerpt}; +use edit_prediction_context::Line; +use language::{BufferSnapshot, Point}; +use std::ops::Range; + +pub fn merge_excerpts( + buffer: &BufferSnapshot, + sorted_line_ranges: impl IntoIterator>, +) -> Vec { + let mut output = Vec::new(); + let mut merged_ranges = Vec::>::new(); + + for line_range in sorted_line_ranges { + if let Some(last_line_range) = merged_ranges.last_mut() + && line_range.start <= last_line_range.end + { + last_line_range.end = last_line_range.end.max(line_range.end); + continue; + } + merged_ranges.push(line_range); + } + + let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None); + let mut outline_items = outline_items.into_iter().peekable(); + + for range in merged_ranges { + let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0); + + while let Some(outline_item) = outline_items.peek() { + if outline_item.range.start >= point_range.start { + break; + } + if outline_item.range.end > point_range.start { + let mut point_range = outline_item.source_range_for_text.clone(); + point_range.start.column = 0; + point_range.end.column = buffer.line_len(point_range.end.row); + + output.push(Excerpt { + start_line: Line(point_range.start.row), + text: buffer + .text_for_range(point_range.clone()) + .collect::() + .into(), + }) + } + outline_items.next(); + } + + output.push(Excerpt { + start_line: Line(point_range.start.row), + text: buffer + .text_for_range(point_range.clone()) + .collect::() + .into(), + }) + } + + output +} + +pub fn write_merged_excerpts( + buffer: &BufferSnapshot, + sorted_line_ranges: impl IntoIterator>, + sorted_insertions: &[(predict_edits_v3::Point, &str)], + output: &mut String, +) { + cloud_zeta2_prompt::write_excerpts( + merge_excerpts(buffer, sorted_line_ranges).iter(), + sorted_insertions, + Line(buffer.max_point().row), + true, + output, + ); +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use gpui::{TestAppContext, prelude::*}; + use indoc::indoc; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; + use pretty_assertions::assert_eq; + use util::test::marked_text_ranges; + + #[gpui::test] + fn test_rust(cx: &mut TestAppContext) { + let table = [ + ( + indoc! {r#" + struct User { + first_name: String, + « last_name: String, + ageˇ: u32, + » email: String, + create_at: Instant, + } + + impl User { + pub fn first_name(&self) -> String { + self.first_name.clone() + } + + pub fn full_name(&self) -> String { + « format!("{} {}", self.first_name, self.last_name) + » } + } + "#}, + indoc! {r#" + 1|struct User { + … + 3| last_name: String, + 4| age<|cursor|>: u32, + … + 9|impl User { + … + 14| pub fn full_name(&self) -> String { + 15| format!("{} {}", self.first_name, self.last_name) + … + "#}, + ), + ( + indoc! {r#" + struct User { + first_name: String, + « last_name: String, + age: u32, + } + »"# + }, + indoc! {r#" + 1|struct User { + … + 3| last_name: String, + 4| age: u32, + 5|} + "#}, + ), + ]; + + for (input, expected_output) in table { + let input_without_ranges = input.replace(['«', '»'], ""); + let input_without_caret = input.replace('ˇ', ""); + let cursor_offset = input_without_ranges.find('ˇ'); + let (input, ranges) = marked_text_ranges(&input_without_caret, false); + let buffer = + cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); + buffer.read_with(cx, |buffer, _cx| { + let insertions = cursor_offset + .map(|offset| { + let point = buffer.offset_to_point(offset); + vec![( + predict_edits_v3::Point { + line: Line(point.row), + column: point.column, + }, + "<|cursor|>", + )] + }) + .unwrap_or_default(); + let ranges: Vec> = ranges + .into_iter() + .map(|range| { + let point_range = range.to_point(&buffer); + Line(point_range.start.row)..Line(point_range.end.row) + }) + .collect(); + + let mut output = String::new(); + write_merged_excerpts(&buffer.snapshot(), ranges, &insertions, &mut output); + assert_eq!(output, expected_output); + }); + } + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(language::tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index 3c0dd75cc23a6a7b18a0fba19d0eab0a4833ba9c..a19e7f9a1da5e1808c48e3ce0469d8b390698760 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -116,6 +116,10 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { return; } + self.zeta.update(cx, |zeta, cx| { + zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx); + }); + let pending_prediction_id = self.next_pending_prediction_id; self.next_pending_prediction_id += 1; let last_request_timestamp = self.last_request_timestamp; diff --git a/crates/zeta2/src/related_excerpts.rs b/crates/zeta2/src/related_excerpts.rs new file mode 100644 index 0000000000000000000000000000000000000000..2f30ee15dc72720fca896580febc9fa75b1bc346 --- /dev/null +++ b/crates/zeta2/src/related_excerpts.rs @@ -0,0 +1,586 @@ +use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc}; + +use crate::merge_excerpts::write_merged_excerpts; +use anyhow::{Result, anyhow}; +use collections::HashMap; +use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line}; +use futures::{StreamExt, stream::BoxStream}; +use gpui::{App, AsyncApp, Entity, Task}; +use indoc::indoc; +use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role, +}; +use project::{ + Project, WorktreeSettings, + search::{SearchQuery, SearchResult}, +}; +use schemars::JsonSchema; +use serde::Deserialize; +use util::paths::{PathMatcher, PathStyle}; +use workspace::item::Settings as _; + +const SEARCH_PROMPT: &str = indoc! {r#" + ## Task + + You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations + that will serve as context for predicting the next required edit. + + **Your task:** + - Analyze the user's recent edits and current cursor context + - Use the `search` tool to find code that may be relevant for predicting the next edit + - Focus on finding: + - Code patterns that might need similar changes based on the recent edits + - Functions, variables, types, and constants referenced in the current cursor context + - Related implementations, usages, or dependencies that may require consistent updates + + **Important constraints:** + - This conversation has exactly 2 turns + - You must make ALL search queries in your first response via the `search` tool + - All queries will be executed in parallel and results returned together + - In the second turn, you will select the most relevant results via the `select` tool. + + ## User Edits + + {edits} + + ## Current cursor context + + `````filename={current_file_path} + {cursor_excerpt} + ````` + + -- + Use the `search` tool now +"#}; + +const SEARCH_TOOL_NAME: &str = "search"; + +/// Search for relevant code +/// +/// For the best results, run multiple queries at once with a single invocation of this tool. +#[derive(Deserialize, JsonSchema)] +struct SearchToolInput { + /// An array of queries to run for gathering context relevant to the next prediction + #[schemars(length(max = 5))] + queries: Box<[SearchToolQuery]>, +} + +#[derive(Deserialize, JsonSchema)] +struct SearchToolQuery { + /// A glob pattern to match file paths in the codebase + glob: String, + /// A regular expression to match content within the files matched by the glob pattern + regex: String, + /// Whether the regex is case-sensitive. Defaults to false (case-insensitive). + #[serde(default)] + case_sensitive: bool, +} + +const RESULTS_MESSAGE: &str = indoc! {" + Here are the results of your queries combined and grouped by file: + +"}; + +const SELECT_TOOL_NAME: &str = "select"; + +const SELECT_PROMPT: &str = indoc! {" + Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message. + Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits. + Include up to 200 lines in total. +"}; + +/// Select line ranges from search results +#[derive(Deserialize, JsonSchema)] +struct SelectToolInput { + /// The line ranges to select from search results. + ranges: Vec, +} + +/// A specific line range to select from a file +#[derive(Debug, Deserialize, JsonSchema)] +struct SelectLineRange { + /// The file path containing the lines to select + /// Exactly as it appears in the search result codeblocks. + path: PathBuf, + /// The starting line number (1-based) + #[schemars(range(min = 1))] + start_line: u32, + /// The ending line number (1-based, inclusive) + #[schemars(range(min = 1))] + end_line: u32, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct LlmContextOptions { + pub excerpt: EditPredictionExcerptOptions, +} + +pub fn find_related_excerpts<'a>( + buffer: Entity, + cursor_position: Anchor, + project: &Entity, + events: impl Iterator, + options: &LlmContextOptions, + cx: &App, +) -> Task, Vec>>>> { + let language_model_registry = LanguageModelRegistry::global(cx); + let Some(model) = language_model_registry + .read(cx) + .available_models(cx) + .find(|model| { + model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID + && model.id() == LanguageModelId("claude-haiku-4-5-latest".into()) + }) + else { + return Task::ready(Err(anyhow!("could not find claude model"))); + }; + + let mut edits_string = String::new(); + + for event in events { + if let Some(event) = event.to_request_event(cx) { + writeln!(&mut edits_string, "{event}").ok(); + } + } + + if edits_string.is_empty() { + edits_string.push_str("(No user edits yet)"); + } + + // TODO [zeta2] include breadcrumbs? + let snapshot = buffer.read(cx).snapshot(); + let cursor_point = cursor_position.to_point(&snapshot); + let Some(cursor_excerpt) = + EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None) + else { + return Task::ready(Ok(HashMap::default())); + }; + + let current_file_path = snapshot + .file() + .map(|f| f.full_path(cx).display().to_string()) + .unwrap_or_else(|| "untitled".to_string()); + + let prompt = SEARCH_PROMPT + .replace("{edits}", &edits_string) + .replace("{current_file_path}", ¤t_file_path) + .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body); + + let path_style = project.read(cx).path_style(cx); + + let exclude_matcher = { + let global_settings = WorktreeSettings::get_global(cx); + let exclude_patterns = global_settings + .file_scan_exclusions + .sources() + .iter() + .chain(global_settings.private_files.sources().iter()); + + match PathMatcher::new(exclude_patterns, path_style) { + Ok(matcher) => matcher, + Err(err) => { + return Task::ready(Err(anyhow!(err))); + } + } + }; + + let project = project.clone(); + cx.spawn(async move |cx| { + let initial_prompt_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![prompt.into()], + cache: false, + }; + + let mut search_stream = request_tool_call::( + vec![initial_prompt_message.clone()], + SEARCH_TOOL_NAME, + &model, + cx, + ) + .await?; + + let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt + select_request_messages.push(initial_prompt_message); + let mut search_calls = Vec::new(); + + while let Some(event) = search_stream.next().await { + match event? { + LanguageModelCompletionEvent::ToolUse(tool_use) => { + if !tool_use.is_input_complete { + continue; + } + + if tool_use.name.as_ref() == SEARCH_TOOL_NAME { + search_calls.push((select_request_messages.len(), tool_use)); + } else { + log::warn!( + "context gathering model tried to use unknown tool: {}", + tool_use.name + ); + } + } + LanguageModelCompletionEvent::Text(txt) => { + if let Some(LanguageModelRequestMessage { + role: Role::Assistant, + content, + .. + }) = select_request_messages.last_mut() + { + if let Some(MessageContent::Text(existing_text)) = content.last_mut() { + existing_text.push_str(&txt); + } else { + content.push(MessageContent::Text(txt)); + } + } else { + select_request_messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::Text(txt)], + cache: false, + }); + } + } + LanguageModelCompletionEvent::Thinking { text, signature } => { + if let Some(LanguageModelRequestMessage { + role: Role::Assistant, + content, + .. + }) = select_request_messages.last_mut() + { + if let Some(MessageContent::Thinking { + text: existing_text, + signature: existing_signature, + }) = content.last_mut() + { + existing_text.push_str(&text); + *existing_signature = signature; + } else { + content.push(MessageContent::Thinking { text, signature }); + } + } else { + select_request_messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::Thinking { text, signature }], + cache: false, + }); + } + } + LanguageModelCompletionEvent::RedactedThinking { data } => { + if let Some(LanguageModelRequestMessage { + role: Role::Assistant, + content, + .. + }) = select_request_messages.last_mut() + { + if let Some(MessageContent::RedactedThinking(existing_data)) = + content.last_mut() + { + existing_data.push_str(&data); + } else { + content.push(MessageContent::RedactedThinking(data)); + } + } else { + select_request_messages.push(LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::RedactedThinking(data)], + cache: false, + }); + } + } + ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => { + log::error!("{ev:?}"); + } + ev => { + log::trace!("context search event: {ev:?}") + } + } + } + + struct ResultBuffer { + buffer: Entity, + snapshot: TextBufferSnapshot, + } + + let mut result_buffers_by_path = HashMap::default(); + + for (index, tool_use) in search_calls.into_iter().rev() { + let call = serde_json::from_value::(tool_use.input.clone())?; + + let mut excerpts_by_buffer = HashMap::default(); + + for query in call.queries { + // TODO [zeta2] parallelize? + + run_query( + query, + &mut excerpts_by_buffer, + path_style, + exclude_matcher.clone(), + &project, + cx, + ) + .await?; + } + + if excerpts_by_buffer.is_empty() { + continue; + } + + let mut merged_result = RESULTS_MESSAGE.to_string(); + + for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer { + excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); + + buffer_entity + .clone() + .read_with(cx, |buffer, cx| { + let Some(file) = buffer.file() else { + return; + }; + + let path = file.full_path(cx); + + writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap(); + + let snapshot = buffer.snapshot(); + + write_merged_excerpts( + &snapshot, + excerpts_for_buffer, + &[], + &mut merged_result, + ); + + merged_result.push_str("`````\n\n"); + + result_buffers_by_path.insert( + path, + ResultBuffer { + buffer: buffer_entity, + snapshot: snapshot.text, + }, + ); + }) + .ok(); + } + + let tool_result = LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: SEARCH_TOOL_NAME.into(), + is_error: false, + content: merged_result.into(), + output: None, + }; + + // Almost always appends at the end, but in theory, the model could return some text after the tool call + // or perform parallel tool calls, so we splice at the message index for correctness. + select_request_messages.splice( + index..index, + [ + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false, + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::ToolResult(tool_result)], + cache: false, + }, + ], + ); + } + + if result_buffers_by_path.is_empty() { + log::trace!("context gathering queries produced no results"); + return anyhow::Ok(HashMap::default()); + } + + select_request_messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![SELECT_PROMPT.into()], + cache: false, + }); + + let mut select_stream = request_tool_call::( + select_request_messages, + SELECT_TOOL_NAME, + &model, + cx, + ) + .await?; + let mut selected_ranges = Vec::new(); + + while let Some(event) = select_stream.next().await { + match event? { + LanguageModelCompletionEvent::ToolUse(tool_use) => { + if !tool_use.is_input_complete { + continue; + } + + if tool_use.name.as_ref() == SELECT_TOOL_NAME { + let call = + serde_json::from_value::(tool_use.input.clone())?; + selected_ranges.extend(call.ranges); + } else { + log::warn!( + "context gathering model tried to use unknown tool: {}", + tool_use.name + ); + } + } + ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => { + log::error!("{ev:?}"); + } + ev => { + log::trace!("context select event: {ev:?}") + } + } + } + + if selected_ranges.is_empty() { + log::trace!("context gathering selected no ranges") + } + + let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default(); + + for selected_range in selected_ranges { + if let Some(ResultBuffer { buffer, snapshot }) = + result_buffers_by_path.get(&selected_range.path) + { + let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0); + let end_point = + snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left); + let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point); + + related_excerpts_by_buffer + .entry(buffer.clone()) + .or_default() + .push(range); + } else { + log::warn!( + "selected path that wasn't included in search results: {}", + selected_range.path.display() + ); + } + } + + for (buffer, ranges) in &mut related_excerpts_by_buffer { + buffer.read_with(cx, |buffer, _cx| { + ranges.sort_unstable_by(|a, b| { + a.start + .cmp(&b.start, buffer) + .then(b.end.cmp(&a.end, buffer)) + }); + })?; + } + + anyhow::Ok(related_excerpts_by_buffer) + }) +} + +async fn request_tool_call( + messages: Vec, + tool_name: &'static str, + model: &Arc, + cx: &mut AsyncApp, +) -> Result>> +{ + let schema = schemars::schema_for!(T); + + let request = LanguageModelRequest { + messages, + tools: vec![LanguageModelRequestTool { + name: tool_name.into(), + description: schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap() + .to_string(), + input_schema: serde_json::to_value(schema).unwrap(), + }], + ..Default::default() + }; + + Ok(model.stream_completion(request, cx).await?) +} + +const MIN_EXCERPT_LEN: usize = 16; +const MAX_EXCERPT_LEN: usize = 768; +const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5; + +async fn run_query( + args: SearchToolQuery, + excerpts_by_buffer: &mut HashMap, Vec>>, + path_style: PathStyle, + exclude_matcher: PathMatcher, + project: &Entity, + cx: &mut AsyncApp, +) -> Result<()> { + let include_matcher = PathMatcher::new(vec![args.glob], path_style)?; + + let query = SearchQuery::regex( + &args.regex, + false, + args.case_sensitive, + false, + true, + include_matcher, + exclude_matcher, + true, + None, + )?; + + let results = project.update(cx, |project, cx| project.search(query, cx))?; + futures::pin_mut!(results); + + let mut total_bytes = 0; + + while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { + if ranges.is_empty() { + continue; + } + + let excerpts_for_buffer = excerpts_by_buffer + .entry(buffer.clone()) + .or_insert_with(|| Vec::with_capacity(ranges.len())); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for range in ranges { + let offset_range = range.to_offset(&snapshot); + let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot); + + if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY { + break; + } + + let excerpt = EditPredictionExcerpt::select_from_buffer( + query_point, + &snapshot, + &EditPredictionExcerptOptions { + max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes), + min_bytes: MIN_EXCERPT_LEN, + target_before_cursor_over_total_bytes: 0.5, + }, + None, + ); + + if let Some(excerpt) = excerpt { + total_bytes += excerpt.range.len(); + if !excerpt.line_range.is_empty() { + excerpts_for_buffer.push(excerpt.line_range); + } + } + } + + if excerpts_for_buffer.is_empty() { + excerpts_by_buffer.remove(&buffer); + } + } + + anyhow::Ok(()) +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 42eb565502e6568491e820dfb5c0921e4d56039b..48eda0f79aec57c6061c2287a80a8075e5badc74 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -6,10 +6,12 @@ use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; -use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt}; +use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt}; +use collections::HashMap; use edit_prediction_context::{ DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, - EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState, + EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line, + SyntaxIndex, SyntaxIndexState, }; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncReadExt as _; @@ -19,25 +21,32 @@ use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, }; -use language::BufferSnapshot; -use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; +use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; +use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; use serde::de::DeserializeOwned; -use std::collections::{HashMap, VecDeque, hash_map}; +use std::collections::{VecDeque, hash_map}; +use std::ops::Range; use std::path::Path; use std::str::FromStr as _; use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; +use util::ResultExt as _; use util::rel_path::RelPathBuf; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +mod merge_excerpts; mod prediction; mod provider; +mod related_excerpts; +use crate::merge_excerpts::merge_excerpts; use crate::prediction::EditPrediction; +pub use crate::related_excerpts::LlmContextOptions; +use crate::related_excerpts::find_related_excerpts; pub use provider::ZetaEditPredictionProvider; const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); @@ -45,19 +54,28 @@ const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; -pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions { - use_imports: true, - max_retrieved_declarations: 0, - excerpt: EditPredictionExcerptOptions { - max_bytes: 512, - min_bytes: 128, - target_before_cursor_over_total_bytes: 0.5, - }, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps: true, - }, +pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, +}; + +pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS); + +pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions { + excerpt: DEFAULT_EXCERPT_OPTIONS, }; +pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions = + EditPredictionContextOptions { + use_imports: true, + max_retrieved_declarations: 0, + excerpt: DEFAULT_EXCERPT_OPTIONS, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps: true, + }, + }; + pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { context: DEFAULT_CONTEXT_OPTIONS, max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, @@ -94,13 +112,28 @@ pub struct Zeta { #[derive(Debug, Clone, PartialEq)] pub struct ZetaOptions { - pub context: EditPredictionContextOptions, + pub context: ContextMode, pub max_prompt_bytes: usize, pub max_diagnostic_bytes: usize, pub prompt_format: predict_edits_v3::PromptFormat, pub file_indexing_parallelism: usize, } +#[derive(Debug, Clone, PartialEq)] +pub enum ContextMode { + Llm(LlmContextOptions), + Syntax(EditPredictionContextOptions), +} + +impl ContextMode { + pub fn excerpt(&self) -> &EditPredictionExcerptOptions { + match self { + ContextMode::Llm(options) => &options.excerpt, + ContextMode::Syntax(options) => &options.excerpt, + } + } +} + pub struct PredictionDebugInfo { pub request: predict_edits_v3::PredictEditsRequest, pub retrieval_time: TimeDelta, @@ -117,6 +150,10 @@ struct ZetaProject { events: VecDeque, registered_buffers: HashMap, current_prediction: Option, + context: Option, Vec>>>, + refresh_context_task: Option>>, + refresh_context_debounce_task: Option>>, + refresh_context_timestamp: Option, } #[derive(Debug, Clone)] @@ -183,6 +220,44 @@ pub enum Event { }, } +impl Event { + pub fn to_request_event(&self, cx: &App) -> Option { + match self { + Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + let path = new_snapshot.file().map(|f| f.full_path(cx)); + + let old_path = old_snapshot.file().and_then(|f| { + let old_path = f.full_path(cx); + if Some(&old_path) != path.as_ref() { + Some(old_path) + } else { + None + } + }); + + // TODO [zeta2] move to bg? + let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); + + if path == old_path && diff.is_empty() { + None + } else { + Some(predict_edits_v3::Event::BufferChange { + old_path, + path, + diff, + //todo: Actually detect if this edit was predicted or not + predicted: false, + }) + } + } + } + } +} + impl Zeta { pub fn try_global(cx: &App) -> Option> { cx.try_global::().map(|global| global.0.clone()) @@ -206,7 +281,7 @@ impl Zeta { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); Self { - projects: HashMap::new(), + projects: HashMap::default(), client, user_store, options: DEFAULT_OPTIONS, @@ -248,6 +323,14 @@ impl Zeta { } } + pub fn history_for_project(&self, project: &Entity) -> impl Iterator { + static EMPTY_EVENTS: VecDeque = VecDeque::new(); + self.projects + .get(&project.entity_id()) + .map_or(&EMPTY_EVENTS, |project| &project.events) + .iter() + } + pub fn usage(&self, cx: &App) -> Option { self.user_store.read(cx).edit_prediction_usage() } @@ -278,8 +361,12 @@ impl Zeta { SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) }), events: VecDeque::new(), - registered_buffers: HashMap::new(), + registered_buffers: HashMap::default(), current_prediction: None, + context: None, + refresh_context_task: None, + refresh_context_debounce_task: None, + refresh_context_timestamp: None, }) } @@ -507,7 +594,10 @@ impl Zeta { }); let options = self.options.clone(); let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else { + let Some(excerpt_path) = snapshot + .file() + .map(|path| -> Arc { path.full_path(cx).into() }) + else { return Task::ready(Err(anyhow!("No file path for excerpt"))); }; let client = self.client.clone(); @@ -525,40 +615,7 @@ impl Zeta { state .events .iter() - .filter_map(|event| match event { - Event::BufferChange { - old_snapshot, - new_snapshot, - .. - } => { - let path = new_snapshot.file().map(|f| f.full_path(cx)); - - let old_path = old_snapshot.file().and_then(|f| { - let old_path = f.full_path(cx); - if Some(&old_path) != path.as_ref() { - Some(old_path) - } else { - None - } - }); - - // TODO [zeta2] move to bg? - let diff = - language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); - - if path == old_path && diff.is_empty() { - None - } else { - Some(predict_edits_v3::Event::BufferChange { - old_path, - path, - diff, - //todo: Actually detect if this edit was predicted or not - predicted: false, - }) - } - } - }) + .filter_map(|event| event.to_request_event(cx)) .collect::>() }) .unwrap_or_default(); @@ -573,6 +630,20 @@ impl Zeta { // TODO data collection let can_collect_data = cx.is_staff(); + let mut included_files = project_state + .and_then(|project_state| project_state.context.as_ref()) + .unwrap_or(&HashMap::default()) + .iter() + .filter_map(|(buffer, ranges)| { + let buffer = buffer.read(cx); + Some(( + buffer.snapshot(), + buffer.file()?.full_path(cx).into(), + ranges.clone(), + )) + }) + .collect::>(); + let request_task = cx.background_spawn({ let snapshot = snapshot.clone(); let buffer = buffer.clone(); @@ -588,18 +659,6 @@ impl Zeta { let before_retrieval = chrono::Utc::now(); - let Some(context) = EditPredictionContext::gather_context( - cursor_point, - &snapshot, - parent_abs_path.as_deref(), - &options.context, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; - - let retrieval_time = chrono::Utc::now() - before_retrieval; - let (diagnostic_groups, diagnostic_groups_truncated) = Self::gather_nearby_diagnostics( cursor_offset, @@ -608,26 +667,127 @@ impl Zeta { options.max_diagnostic_bytes, ); - let request = make_cloud_request( - excerpt_path, - context, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - None, - debug_tx.is_some(), - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - ); + let request = match options.context { + ContextMode::Llm(context_options) => { + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &snapshot, + &context_options.excerpt, + index_state.as_deref(), + ) else { + return Ok((None, None)); + }; + + let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start) + ..snapshot.anchor_before(excerpt.range.end); + + if let Some(buffer_ix) = included_files + .iter() + .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id()) + { + let (buffer, _, ranges) = &mut included_files[buffer_ix]; + let range_ix = ranges + .binary_search_by(|probe| { + probe + .start + .cmp(&excerpt_anchor_range.start, buffer) + .then(excerpt_anchor_range.end.cmp(&probe.end, buffer)) + }) + .unwrap_or_else(|ix| ix); + + ranges.insert(range_ix, excerpt_anchor_range); + let last_ix = included_files.len() - 1; + included_files.swap(buffer_ix, last_ix); + } else { + included_files.push(( + snapshot, + excerpt_path.clone(), + vec![excerpt_anchor_range], + )); + } + + let included_files = included_files + .into_iter() + .map(|(buffer, path, ranges)| { + let excerpts = merge_excerpts( + &buffer, + ranges.iter().map(|range| { + let point_range = range.to_point(&buffer); + Line(point_range.start.row)..Line(point_range.end.row) + }), + ); + predict_edits_v3::IncludedFile { + path, + max_row: Line(buffer.max_point().row), + excerpts, + } + }) + .collect::>(); + + predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + included_files, + referenced_declarations: vec![], + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + // TODO [zeta2] + signatures: vec![], + excerpt_parent: None, + git_info: None, + } + } + ContextMode::Syntax(context_options) => { + let Some(context) = EditPredictionContext::gather_context( + cursor_point, + &snapshot, + parent_abs_path.as_deref(), + &context_options, + index_state.as_deref(), + ) else { + return Ok((None, None)); + }; + + make_syntax_context_cloud_request( + excerpt_path, + context, + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + None, + debug_tx.is_some(), + &worktree_snapshots, + index_state.as_deref(), + Some(options.max_prompt_bytes), + options.prompt_format, + ) + } + }; + + let retrieval_time = chrono::Utc::now() - before_retrieval; let debug_response_tx = if let Some(debug_tx) = &debug_tx { let (response_tx, response_rx) = oneshot::channel(); - let local_prompt = PlannedPrompt::populate(&request) - .and_then(|p| p.to_prompt_string().map(|p| p.0)) + if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() + { + } else { + }; + + let local_prompt = build_prompt(&request) + .map(|(prompt, _)| prompt) .map_err(|err| err.to_string()); debug_tx @@ -827,6 +987,103 @@ impl Zeta { } } + pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); + pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); + + // Refresh the related excerpts when the user just beguns editing after + // an idle period, and after they pause editing. + fn refresh_context_if_needed( + &mut self, + project: &Entity, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) { + if !matches!(&self.options().context, ContextMode::Llm { .. }) { + return; + } + + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + let now = Instant::now(); + let was_idle = zeta_project + .refresh_context_timestamp + .map_or(true, |timestamp| { + now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION + }); + zeta_project.refresh_context_timestamp = Some(now); + zeta_project.refresh_context_debounce_task = Some(cx.spawn({ + let buffer = buffer.clone(); + let project = project.clone(); + async move |this, cx| { + if was_idle { + log::debug!("refetching edit prediction context after idle"); + } else { + cx.background_executor() + .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) + .await; + log::debug!("refetching edit prediction context after pause"); + } + this.update(cx, |this, cx| { + this.refresh_context(project, buffer, cursor_position, cx); + }) + .ok() + } + })); + } + + // Refresh the related excerpts asynchronously. Ensure the task runs to completion, + // and avoid spawning more than one concurrent task. + fn refresh_context( + &mut self, + project: Entity, + buffer: Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + zeta_project + .refresh_context_task + .get_or_insert(cx.spawn(async move |this, cx| { + let related_excerpts = this + .update(cx, |this, cx| { + let Some(zeta_project) = this.projects.get(&project.entity_id()) else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; + + let ContextMode::Llm(options) = &this.options().context else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; + + find_related_excerpts( + buffer.clone(), + cursor_position, + &project, + zeta_project.events.iter(), + options, + cx, + ) + }) + .ok()? + .await + .log_err() + .unwrap_or_default(); + this.update(cx, |this, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { + return; + }; + zeta_project.context = Some(related_excerpts); + zeta_project.refresh_context_task.take(); + }) + .ok() + })); + } + fn gather_nearby_diagnostics( cursor_offset: usize, diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], @@ -918,12 +1175,20 @@ impl Zeta { cursor_point, &snapshot, parent_abs_path.as_deref(), - &options.context, + match &options.context { + ContextMode::Llm(_) => { + // TODO + panic!("Llm mode not supported in zeta cli yet"); + } + ContextMode::Syntax(edit_prediction_context_options) => { + edit_prediction_context_options + } + }, index_state.as_deref(), ) .context("Failed to select excerpt") .map(|context| { - make_cloud_request( + make_syntax_context_cloud_request( excerpt_path.into(), context, // TODO pass everything @@ -963,7 +1228,7 @@ pub struct ZedUpdateRequiredError { minimum_version: SemanticVersion, } -fn make_cloud_request( +fn make_syntax_context_cloud_request( excerpt_path: Arc, context: EditPredictionContext, events: Vec, @@ -1044,6 +1309,7 @@ fn make_cloud_request( column: context.cursor_point.column, }, referenced_declarations, + included_files: vec![], signatures, excerpt_parent, events, diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 2319df2a49d04c7e73180830ecf9778380bbf025..d44852971b3a06b240ab1a827989cf81c0be58de 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -20,7 +20,10 @@ use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, pr use ui_input::InputField; use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; -use zeta2::{PredictionDebugInfo, Zeta, Zeta2FeatureFlag, ZetaOptions}; +use zeta2::{ + ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, PredictionDebugInfo, Zeta, + Zeta2FeatureFlag, ZetaOptions, +}; use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions}; @@ -69,7 +72,7 @@ pub struct Zeta2Inspector { min_excerpt_bytes_input: Entity, cursor_context_ratio_input: Entity, max_prompt_bytes_input: Entity, - max_retrieved_declarations: Entity, + context_mode: ContextModeState, active_view: ActiveView, zeta: Entity, _active_editor_subscription: Option, @@ -77,6 +80,13 @@ pub struct Zeta2Inspector { _receive_task: Task<()>, } +pub enum ContextModeState { + Llm, + Syntax { + max_retrieved_declarations: Entity, + }, +} + #[derive(PartialEq)] enum ActiveView { Context, @@ -143,36 +153,34 @@ impl Zeta2Inspector { min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx), cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx), max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx), - max_retrieved_declarations: Self::number_input("Max Retrieved Definitions", window, cx), + context_mode: ContextModeState::Llm, zeta: zeta.clone(), _active_editor_subscription: None, _update_state_task: Task::ready(()), _receive_task: receive_task, }; - this.set_input_options(&zeta.read(cx).options().clone(), window, cx); + this.set_options_state(&zeta.read(cx).options().clone(), window, cx); this } - fn set_input_options( + fn set_options_state( &mut self, options: &ZetaOptions, window: &mut Window, cx: &mut Context, ) { + let excerpt_options = options.context.excerpt(); self.max_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(options.context.excerpt.max_bytes.to_string(), window, cx); + input.set_text(excerpt_options.max_bytes.to_string(), window, cx); }); self.min_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(options.context.excerpt.min_bytes.to_string(), window, cx); + input.set_text(excerpt_options.min_bytes.to_string(), window, cx); }); self.cursor_context_ratio_input.update(cx, |input, cx| { input.set_text( format!( "{:.2}", - options - .context - .excerpt - .target_before_cursor_over_total_bytes + excerpt_options.target_before_cursor_over_total_bytes ), window, cx, @@ -181,20 +189,28 @@ impl Zeta2Inspector { self.max_prompt_bytes_input.update(cx, |input, cx| { input.set_text(options.max_prompt_bytes.to_string(), window, cx); }); - self.max_retrieved_declarations.update(cx, |input, cx| { - input.set_text( - options.context.max_retrieved_declarations.to_string(), - window, - cx, - ); - }); + + match &options.context { + ContextMode::Llm(_) => { + self.context_mode = ContextModeState::Llm; + } + ContextMode::Syntax(_) => { + self.context_mode = ContextModeState::Syntax { + max_retrieved_declarations: Self::number_input( + "Max Retrieved Definitions", + window, + cx, + ), + }; + } + } cx.notify(); } - fn set_options(&mut self, options: ZetaOptions, cx: &mut Context) { + fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context) { self.zeta.update(cx, |this, _cx| this.set_options(options)); - const THROTTLE_TIME: Duration = Duration::from_millis(100); + const DEBOUNCE_TIME: Duration = Duration::from_millis(100); if let Some(prediction) = self.last_prediction.as_mut() { if let Some(buffer) = prediction.buffer.upgrade() { @@ -202,7 +218,7 @@ impl Zeta2Inspector { let zeta = self.zeta.clone(); let project = self.project.clone(); prediction._task = Some(cx.spawn(async move |_this, cx| { - cx.background_executor().timer(THROTTLE_TIME).await; + cx.background_executor().timer(DEBOUNCE_TIME).await; if let Some(task) = zeta .update(cx, |zeta, cx| { zeta.refresh_prediction(&project, &buffer, position, cx) @@ -255,25 +271,40 @@ impl Zeta2Inspector { let zeta_options = this.zeta.read(cx).options().clone(); - let context_options = EditPredictionContextOptions { - excerpt: EditPredictionExcerptOptions { - max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx), - min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx), - target_before_cursor_over_total_bytes: number_input_value( - &this.cursor_context_ratio_input, - cx, - ), - }, - max_retrieved_declarations: number_input_value( - &this.max_retrieved_declarations, + let excerpt_options = EditPredictionExcerptOptions { + max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx), + min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx), + target_before_cursor_over_total_bytes: number_input_value( + &this.cursor_context_ratio_input, cx, ), - ..zeta_options.context }; - this.set_options( + let context = match zeta_options.context { + ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions { + excerpt: excerpt_options, + }), + ContextMode::Syntax(context_options) => { + let max_retrieved_declarations = match &this.context_mode { + ContextModeState::Llm => { + zeta2::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations + } + ContextModeState::Syntax { + max_retrieved_declarations, + } => number_input_value(max_retrieved_declarations, cx), + }; + + ContextMode::Syntax(EditPredictionContextOptions { + excerpt: excerpt_options, + max_retrieved_declarations, + ..context_options + }) + } + }; + + this.set_zeta_options( ZetaOptions { - context: context_options, + context, max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx), max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, prompt_format: zeta_options.prompt_format, @@ -709,7 +740,7 @@ impl Zeta2Inspector { .style(ButtonStyle::Outlined) .size(ButtonSize::Large) .on_click(cx.listener(|this, _, window, cx| { - this.set_input_options(&zeta2::DEFAULT_OPTIONS, window, cx); + this.set_options_state(&zeta2::DEFAULT_OPTIONS, window, cx); })), ), ) @@ -722,19 +753,113 @@ impl Zeta2Inspector { .items_end() .child(self.max_excerpt_bytes_input.clone()) .child(self.min_excerpt_bytes_input.clone()) - .child(self.cursor_context_ratio_input.clone()), + .child(self.cursor_context_ratio_input.clone()) + .child(self.render_context_mode_dropdown(window, cx)), ) .child( h_flex() .gap_2() .items_end() - .child(self.max_retrieved_declarations.clone()) + .children(match &self.context_mode { + ContextModeState::Llm => None, + ContextModeState::Syntax { + max_retrieved_declarations, + } => Some(max_retrieved_declarations.clone()), + }) .child(self.max_prompt_bytes_input.clone()) .child(self.render_prompt_format_dropdown(window, cx)), ), ) } + fn render_context_mode_dropdown(&self, window: &mut Window, cx: &mut Context) -> Div { + let this = cx.weak_entity(); + + v_flex() + .gap_1p5() + .child( + Label::new("Context Mode") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + DropdownMenu::new( + "ep-ctx-mode", + match &self.context_mode { + ContextModeState::Llm => "LLM-based", + ContextModeState::Syntax { .. } => "Syntax", + }, + ContextMenu::build(window, cx, move |menu, _window, _cx| { + menu.item( + ContextMenuEntry::new("LLM-based") + .toggleable( + IconPosition::End, + matches!(self.context_mode, ContextModeState::Llm), + ) + .handler({ + let this = this.clone(); + move |window, cx| { + this.update(cx, |this, cx| { + let current_options = + this.zeta.read(cx).options().clone(); + match current_options.context.clone() { + ContextMode::Llm(_) => {} + ContextMode::Syntax(context_options) => { + let options = ZetaOptions { + context: ContextMode::Llm( + LlmContextOptions { + excerpt: context_options.excerpt, + }, + ), + ..current_options + }; + this.set_options_state(&options, window, cx); + this.set_zeta_options(options, cx); + } + } + }) + .ok(); + } + }), + ) + .item( + ContextMenuEntry::new("Syntax") + .toggleable( + IconPosition::End, + matches!(self.context_mode, ContextModeState::Syntax { .. }), + ) + .handler({ + move |window, cx| { + this.update(cx, |this, cx| { + let current_options = + this.zeta.read(cx).options().clone(); + match current_options.context.clone() { + ContextMode::Llm(context_options) => { + let options = ZetaOptions { + context: ContextMode::Syntax( + EditPredictionContextOptions { + excerpt: context_options.excerpt, + ..DEFAULT_SYNTAX_CONTEXT_OPTIONS + }, + ), + ..current_options + }; + this.set_options_state(&options, window, cx); + this.set_zeta_options(options, cx); + } + ContextMode::Syntax(_) => {} + } + }) + .ok(); + } + }), + ) + }), + ) + .style(ui::DropdownStyle::Outlined), + ) + } + fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context) -> Div { let active_format = self.zeta.read(cx).options().prompt_format; let this = cx.weak_entity(); @@ -765,7 +890,7 @@ impl Zeta2Inspector { prompt_format, ..current_options }; - this.set_options(options, cx); + this.set_zeta_options(options, cx); }) .ok(); } diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 149b13719f2075143d81c164e8d91bbdaca17384..eea80898870d68a8ad361de43d4556438ed25444 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -20,6 +20,7 @@ use reqwest_client::ReqwestClient; use serde_json::json; use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc}; use zeta::{PerformPredictEditsParams, Zeta}; +use zeta2::ContextMode; use crate::headless::ZetaCliAppState; use crate::source_location::SourceLocation; @@ -263,8 +264,8 @@ async fn get_context( })? .await?; - let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; - let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?; + let (prompt_string, section_labels) = + cloud_zeta2_prompt::build_prompt(&request)?; match zeta2_args.output_format { OutputFormat::Prompt => anyhow::Ok(prompt_string), @@ -301,7 +302,7 @@ async fn get_context( impl Zeta2Args { fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions { zeta2::ZetaOptions { - context: EditPredictionContextOptions { + context: ContextMode::Syntax(EditPredictionContextOptions { max_retrieved_declarations: self.max_retrieved_definitions, use_imports: !self.disable_imports_gathering, excerpt: EditPredictionExcerptOptions { @@ -313,7 +314,7 @@ impl Zeta2Args { score: EditPredictionScoreOptions { omit_excerpt_overlaps, }, - }, + }), max_diagnostic_bytes: self.max_diagnostic_bytes, max_prompt_bytes: self.max_prompt_bytes, prompt_format: self.prompt_format.clone().into(), diff --git a/crates/zeta_cli/src/retrieval_stats.rs b/crates/zeta_cli/src/retrieval_stats.rs index bf1f78200ec5dd9262b6ae8937695b690155e8e2..f2634b1323d92b7136c591627226161b2905a955 100644 --- a/crates/zeta_cli/src/retrieval_stats.rs +++ b/crates/zeta_cli/src/retrieval_stats.rs @@ -3,8 +3,8 @@ use ::util::{RangeExt, ResultExt as _}; use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; use edit_prediction_context::{ - Declaration, DeclarationStyle, EditPredictionContext, Identifier, Imports, Reference, - ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range, + Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, Identifier, + Imports, Reference, ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range, }; use futures::StreamExt as _; use futures::channel::mpsc; @@ -32,6 +32,7 @@ use std::{ time::Duration, }; use util::paths::PathStyle; +use zeta2::ContextMode; use crate::headless::ZetaCliAppState; use crate::source_location::SourceLocation; @@ -46,6 +47,10 @@ pub async fn retrieval_stats( options: zeta2::ZetaOptions, cx: &mut AsyncApp, ) -> Result { + let ContextMode::Syntax(context_options) = options.context.clone() else { + anyhow::bail!("retrieval stats only works in ContextMode::Syntax"); + }; + let options = Arc::new(options); let worktree_path = worktree.canonicalize()?; @@ -264,10 +269,10 @@ pub async fn retrieval_stats( .map(|project_file| { let index_state = index_state.clone(); let lsp_definitions = lsp_definitions.clone(); - let options = options.clone(); let output_tx = output_tx.clone(); let done_count = done_count.clone(); let file_snapshots = file_snapshots.clone(); + let context_options = context_options.clone(); cx.background_spawn(async move { let snapshot = project_file.snapshot; @@ -279,7 +284,7 @@ pub async fn retrieval_stats( &snapshot, ); - let imports = if options.context.use_imports { + let imports = if context_options.use_imports { Imports::gather(&snapshot, Some(&project_file.parent_abs_path)) } else { Imports::default() @@ -311,7 +316,7 @@ pub async fn retrieval_stats( &snapshot, &index_state, &file_snapshots, - &options, + &context_options, ) .await?; @@ -958,7 +963,7 @@ async fn retrieve_definitions( snapshot: &BufferSnapshot, index: &Arc, file_snapshots: &Arc>, - options: &Arc, + context_options: &EditPredictionContextOptions, ) -> Result { let mut single_reference_map = HashMap::default(); single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]); @@ -966,7 +971,7 @@ async fn retrieve_definitions( query_point, snapshot, imports, - &options.context, + &context_options, Some(&index), |_, _, _| single_reference_map, );