From 4532765ae845b8c98c73c88cc916f7d771b429d5 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Tue, 23 Sep 2025 00:20:26 -0600 Subject: [PATCH] zeta2: Add prompt planner and provide access via zeta_cli (#38691) Release Notes: - N/A --- Cargo.lock | 15 + Cargo.toml | 2 + .../cloud_llm_client/src/predict_edits_v3.rs | 5 +- crates/cloud_zeta2_prompt/Cargo.toml | 20 + crates/cloud_zeta2_prompt/LICENSE-GPL | 1 + .../src/cloud_zeta2_prompt.rs | 396 ++++++++++++++++++ .../src/declaration.rs | 58 +-- .../src/declaration_scoring.rs | 5 +- .../src/syntax_index.rs | 4 +- .../src/wip_requests.rs | 35 -- crates/zeta2/src/zeta2.rs | 77 +++- crates/zeta_cli/Cargo.toml | 5 +- crates/zeta_cli/src/main.rs | 143 +++++-- 13 files changed, 667 insertions(+), 99 deletions(-) create mode 100644 crates/cloud_zeta2_prompt/Cargo.toml create mode 120000 crates/cloud_zeta2_prompt/LICENSE-GPL create mode 100644 crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs delete mode 100644 crates/edit_prediction_context/src/wip_requests.rs diff --git a/Cargo.lock b/Cargo.lock index 5e704d56697b1460c9a3a705f852e3287185bdf2..1819b62c3434ad8bdea6dd526b7f68122378e290 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3225,6 +3225,18 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "cloud_zeta2_prompt" +version = "0.1.0" +dependencies = [ + "anyhow", + "cloud_llm_client", + "ordered-float 2.10.1", + "rustc-hash 2.1.1", + "strum 0.27.1", + "workspace-hack", +] + [[package]] name = "clru" version = "0.6.2" @@ -21683,7 +21695,9 @@ dependencies = [ "anyhow", "clap", "client", + "cloud_zeta2_prompt", "debug_adapter_extension", + "edit_prediction_context", "extension", "fs", "futures 0.3.31", @@ -21710,6 +21724,7 @@ dependencies = [ "watch", "workspace-hack", "zeta", + "zeta2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6e1950aaeea715dd85c98a443b6116a619b0e3f7..3c431a29eb53420dadde38a1c1ad30a1f61d44c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ members = [ "crates/cloud_api_client", "crates/cloud_api_types", "crates/cloud_llm_client", + "crates/cloud_zeta2_prompt", "crates/collab", "crates/collab_ui", "crates/collections", @@ -271,6 +272,7 @@ clock = { path = "crates/clock" } cloud_api_client = { path = "crates/cloud_api_client" } cloud_api_types = { path = "crates/cloud_api_types" } cloud_llm_client = { path = "crates/cloud_llm_client" } +cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" } collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections" } diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 60621b1f14714439b0527078c07e2865799172f3..35d05ff81f9eb72c4b8261dc3e3340b18c79ebfa 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -48,6 +48,9 @@ pub struct Signature { pub text_is_truncated: bool, #[serde(skip_serializing_if = "Option::is_none", default)] pub parent_index: Option, + /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The + /// file is implicitly the file that contains the descendant declaration or excerpt. 
+ pub range: Range, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -55,7 +58,7 @@ pub struct ReferencedDeclaration { pub path: PathBuf, pub text: String, pub text_is_truncated: bool, - /// Range of `text` within file, potentially truncated according to `text_is_truncated` + /// Range of `text` within file, possibly truncated according to `text_is_truncated` pub range: Range, /// Range within `text` pub signature_range: Range, diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a1194f13615964fd3013eb8dbdf3057984946e32 --- /dev/null +++ b/crates/cloud_zeta2_prompt/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "cloud_zeta2_prompt" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/cloud_zeta2_prompt.rs" + +[dependencies] +anyhow.workspace = true +cloud_llm_client.workspace = true +ordered-float.workspace = true +rustc-hash.workspace = true +strum.workspace = true +workspace-hack.workspace = true diff --git a/crates/cloud_zeta2_prompt/LICENSE-GPL b/crates/cloud_zeta2_prompt/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/cloud_zeta2_prompt/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs new file mode 100644 index 0000000000000000000000000000000000000000..6690380c74b0d4880210b683f34eea1d98a7946b --- /dev/null +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -0,0 +1,396 @@ +//! Zeta2 prompt planning and generation code shared with cloud. + +use anyhow::{Result, anyhow}; +use cloud_llm_client::predict_edits_v3::{self, ReferencedDeclaration}; +use ordered_float::OrderedFloat; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path}; +use strum::{EnumIter, IntoEnumIterator}; + +pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; +/// NOTE: Differs from zed version of constant - includes a newline +pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>\n"; +/// NOTE: Differs from zed version of constant - includes a newline +pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>\n"; + +pub struct PlannedPrompt<'a> { + request: &'a predict_edits_v3::PredictEditsRequest, + /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in + /// `to_prompt_string`. + snippets: Vec>, + budget_used: usize, +} + +pub struct PlanOptions { + pub max_bytes: usize, +} + +#[derive(Clone, Debug)] +pub struct PlannedSnippet<'a> { + path: &'a Path, + range: Range, + text: &'a str, + // TODO: Indicate this in the output + #[allow(dead_code)] + text_is_truncated: bool, +} + +#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] +pub enum SnippetStyle { + Signature, + Declaration, +} + +impl<'a> PlannedPrompt<'a> { + /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following: + /// + /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle + /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature" + /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of + /// upgrade. 
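As a side note, the selection described above is a standard greedy knapsack ordered by score density. Below is a minimal, self-contained sketch of the idea, not the code added in this patch: `Item`, `Entry`, and `pick` are hypothetical names, sizes are assumed nonzero, and every item is seeded as a signature for simplicity, whereas the implementation that follows seeds the queue with whichever style has the higher score density.

use std::{cmp::Ordering, collections::BinaryHeap};

struct Item {
    signature_size: usize,
    declaration_size: usize,
    signature_score: f32,
    declaration_score: f32,
}

struct Entry {
    density: f32, // score per byte; the heap pops the highest density first
    index: usize,
    upgrade: bool, // true = upgrade an already-included signature to the full declaration
}

impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        self.density == other.density
    }
}
impl Eq for Entry {}
impl PartialOrd for Entry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Entry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.density.total_cmp(&other.density)
    }
}

fn pick(items: &[Item], max_bytes: usize) -> Vec<(usize, bool)> {
    // Seed the queue with the cheap (signature) form of every item.
    let mut queue: BinaryHeap<Entry> = items
        .iter()
        .enumerate()
        .map(|(index, item)| Entry {
            density: item.signature_score / item.signature_size as f32,
            index,
            upgrade: false,
        })
        .collect();
    let mut used = 0;
    let mut picked = Vec::new();
    while let Some(entry) = queue.pop() {
        let item = &items[entry.index];
        let cost = if entry.upgrade {
            item.declaration_size - item.signature_size
        } else {
            item.signature_size
        };
        if used + cost > max_bytes {
            continue; // over budget for now; a smaller, lower-density entry may still fit
        }
        used += cost;
        picked.push((entry.index, entry.upgrade));
        if !entry.upgrade && item.declaration_size > item.signature_size {
            // Re-enqueue the marginal value of upgrading signature -> full declaration.
            queue.push(Entry {
                density: (item.declaration_score - item.signature_score)
                    / (item.declaration_size - item.signature_size) as f32,
                index: entry.index,
                upgrade: true,
            });
        }
    }
    picked
}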
+ /// + /// TODO: Implement an early halting condition. One option might be to have another priority + /// queue where the score is the size, and update it accordingly. Another option might be to + /// have some simpler heuristic like bailing after N failed insertions, or based on how much + /// budget is left. + /// + /// TODO: Has the current known sources of imprecision: + /// + /// * Does not consider snippet overlap when ranking. For example, it might add a field to the + /// plan even though the containing struct is already included. + /// + /// * Does not consider cost of signatures when ranking snippets - this is tricky since + /// signatures may be shared by multiple snippets. + /// + /// * Does not include file paths / other text when considering max_bytes. + pub fn populate( + request: &'a predict_edits_v3::PredictEditsRequest, + options: &PlanOptions, + ) -> Result { + let mut this = PlannedPrompt { + request, + snippets: Vec::new(), + budget_used: request.excerpt.len(), + }; + let mut included_parents = FxHashSet::default(); + let additional_parents = this.additional_parent_signatures( + &request.excerpt_path, + request.excerpt_parent, + &included_parents, + )?; + this.add_parents(&mut included_parents, additional_parents); + + if this.budget_used > options.max_bytes { + return Err(anyhow!( + "Excerpt + signatures size of {} already exceeds budget of {}", + this.budget_used, + options.max_bytes + )); + } + + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] + struct QueueEntry { + score_density: OrderedFloat, + declaration_index: usize, + style: SnippetStyle, + } + + // Initialize priority queue with the best score for each snippet. + let mut queue: BinaryHeap = BinaryHeap::new(); + for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() { + let (style, score_density) = SnippetStyle::iter() + .map(|style| { + ( + style, + OrderedFloat(declaration_score_density(&declaration, style)), + ) + }) + .max_by_key(|(_, score_density)| *score_density) + .unwrap(); + queue.push(QueueEntry { + score_density, + declaration_index, + style, + }); + } + + // Knapsack selection loop + while let Some(queue_entry) = queue.pop() { + let Some(declaration) = request + .referenced_declarations + .get(queue_entry.declaration_index) + else { + return Err(anyhow!( + "Invalid declaration index {}", + queue_entry.declaration_index + )); + }; + + let mut additional_bytes = declaration_size(declaration, queue_entry.style); + if this.budget_used + additional_bytes > options.max_bytes { + continue; + } + + let additional_parents = this.additional_parent_signatures( + &declaration.path, + declaration.parent_index, + &mut included_parents, + )?; + additional_bytes += additional_parents + .iter() + .map(|(_, snippet)| snippet.text.len()) + .sum::(); + if this.budget_used + additional_bytes > options.max_bytes { + continue; + } + + this.budget_used += additional_bytes; + this.add_parents(&mut included_parents, additional_parents); + let planned_snippet = match queue_entry.style { + SnippetStyle::Signature => { + let Some(text) = declaration.text.get(declaration.signature_range.clone()) + else { + return Err(anyhow!( + "Invalid declaration signature_range {:?} with text.len() = {}", + declaration.signature_range, + declaration.text.len() + )); + }; + PlannedSnippet { + path: &declaration.path, + range: (declaration.signature_range.start + declaration.range.start) + ..(declaration.signature_range.end + declaration.range.start), + text, + text_is_truncated: 
declaration.text_is_truncated, + } + } + SnippetStyle::Declaration => PlannedSnippet { + path: &declaration.path, + range: declaration.range.clone(), + text: &declaration.text, + text_is_truncated: declaration.text_is_truncated, + }, + }; + this.snippets.push(planned_snippet); + + // When a Signature is consumed, insert an entry for Definition style. + if queue_entry.style == SnippetStyle::Signature { + let signature_size = declaration_size(&declaration, SnippetStyle::Signature); + let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration); + let signature_score = declaration_score(&declaration, SnippetStyle::Signature); + let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration); + + let score_diff = declaration_score - signature_score; + let size_diff = declaration_size.saturating_sub(signature_size); + if score_diff > 0.0001 && size_diff > 0 { + queue.push(QueueEntry { + declaration_index: queue_entry.declaration_index, + score_density: OrderedFloat(score_diff / (size_diff as f32)), + style: SnippetStyle::Declaration, + }); + } + } + } + + anyhow::Ok(this) + } + + fn add_parents( + &mut self, + included_parents: &mut FxHashSet, + snippets: Vec<(usize, PlannedSnippet<'a>)>, + ) { + for (parent_index, snippet) in snippets { + included_parents.insert(parent_index); + self.budget_used += snippet.text.len(); + self.snippets.push(snippet); + } + } + + fn additional_parent_signatures( + &self, + path: &'a Path, + parent_index: Option, + included_parents: &FxHashSet, + ) -> Result)>> { + let mut results = Vec::new(); + self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?; + Ok(results) + } + + fn additional_parent_signatures_impl( + &self, + path: &'a Path, + parent_index: Option, + included_parents: &FxHashSet, + results: &mut Vec<(usize, PlannedSnippet<'a>)>, + ) -> Result<()> { + let Some(parent_index) = parent_index else { + return Ok(()); + }; + if included_parents.contains(&parent_index) { + return Ok(()); + } + let Some(parent_signature) = self.request.signatures.get(parent_index) else { + return Err(anyhow!("Invalid parent index {}", parent_index)); + }; + results.push(( + parent_index, + PlannedSnippet { + path, + range: parent_signature.range.clone(), + text: &parent_signature.text, + text_is_truncated: parent_signature.text_is_truncated, + }, + )); + self.additional_parent_signatures_impl( + path, + parent_signature.parent_index, + included_parents, + results, + ) + } + + /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple + /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive + /// chunks. 
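For illustration only (hypothetical paths and contents), a rendering with one referenced file and the excerpt file might look roughly like the following; the excerpt file is rendered last, and the cursor and editable-region markers are spliced in only for that file:

```src/lib.rs
fn helper() -> u32 {
    42
}
…
struct Config {
    retries: u32,
}
```

```src/main.rs
<|editable_region_start|>
fn main() {
    let value = helper();<|user_cursor_is_here|>
}
<|editable_region_end|>
```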
+ pub fn to_prompt_string(&self) -> String { + let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> = + FxHashMap::default(); + for snippet in &self.snippets { + file_to_snippets + .entry(&snippet.path) + .or_default() + .push(snippet); + } + + // Reorder so that file with cursor comes last + let mut file_snippets = Vec::new(); + let mut excerpt_file_snippets = Vec::new(); + for (file_path, snippets) in file_to_snippets { + if file_path == &self.request.excerpt_path { + excerpt_file_snippets = snippets; + } else { + file_snippets.push((file_path, snippets, false)); + } + } + let excerpt_snippet = PlannedSnippet { + path: &self.request.excerpt_path, + range: self.request.excerpt_range.clone(), + text: &self.request.excerpt, + text_is_truncated: false, + }; + excerpt_file_snippets.push(&excerpt_snippet); + file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true)); + + let mut excerpt_file_insertions = vec![ + ( + self.request.excerpt_range.start, + EDITABLE_REGION_START_MARKER, + ), + ( + self.request.excerpt_range.start + self.request.cursor_offset, + CURSOR_MARKER, + ), + ( + self.request + .excerpt_range + .end + .saturating_sub(0) + .max(self.request.excerpt_range.start), + EDITABLE_REGION_END_MARKER, + ), + ]; + + fn push_excerpt_file_range( + range: Range, + text: &str, + excerpt_file_insertions: &mut Vec<(usize, &'static str)>, + output: &mut String, + ) { + let mut last_offset = range.start; + let mut i = 0; + while i < excerpt_file_insertions.len() { + let (offset, insertion) = &excerpt_file_insertions[i]; + let found = *offset >= range.start && *offset <= range.end; + if found { + output.push_str(&text[last_offset - range.start..offset - range.start]); + output.push_str(insertion); + last_offset = *offset; + excerpt_file_insertions.remove(i); + continue; + } + i += 1; + } + output.push_str(&text[last_offset - range.start..]); + } + + let mut output = String::new(); + for (file_path, mut snippets, is_excerpt_file) in file_snippets { + output.push_str(&format!("```{}\n", file_path.display())); + + let mut last_included_range: Option> = None; + snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end))); + for snippet in snippets { + if let Some(last_range) = &last_included_range + && snippet.range.start < last_range.end + { + if snippet.range.end <= last_range.end { + continue; + } + // TODO: Should probably also handle case where there is just one char (newline) + // between snippets - assume it's a newline. 
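A minimal sketch of the marker-splicing idea used by `push_excerpt_file_range` above (the `splice_markers` name is hypothetical; it assumes `insertions` is sorted by offset, every offset falls within the chunk, and offsets land on char boundaries):

fn splice_markers(text: &str, range_start: usize, insertions: &[(usize, &str)]) -> String {
    let mut output = String::new();
    let mut last = 0;
    for &(offset, marker) in insertions {
        // Insertions carry absolute file offsets; convert to offsets within `text`.
        let relative = offset - range_start;
        output.push_str(&text[last..relative]);
        output.push_str(marker);
        last = relative;
    }
    output.push_str(&text[last..]);
    output
}

// e.g. splice_markers("let x = y;", 100, &[(107, "<|user_cursor_is_here|>")])
// yields "let x =<|user_cursor_is_here|> y;"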
+ let text = &snippet.text[last_range.end - snippet.range.start..]; + if is_excerpt_file { + push_excerpt_file_range( + last_range.end..snippet.range.end, + text, + &mut excerpt_file_insertions, + &mut output, + ); + } else { + output.push_str(text); + } + last_included_range = Some(last_range.start..snippet.range.end); + continue; + } + if last_included_range.is_some() { + output.push_str("…\n"); + } + if is_excerpt_file { + push_excerpt_file_range( + snippet.range.clone(), + snippet.text, + &mut excerpt_file_insertions, + &mut output, + ); + } else { + output.push_str(snippet.text); + } + last_included_range = Some(snippet.range.clone()); + } + + output.push_str("```\n\n"); + } + + output + } +} + +fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 { + declaration_score(declaration, style) / declaration_size(declaration, style) as f32 +} + +fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 { + match style { + SnippetStyle::Signature => declaration.signature_score, + SnippetStyle::Declaration => declaration.declaration_score, + } +} + +fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize { + match style { + SnippetStyle::Signature => declaration.signature_range.len(), + SnippetStyle::Declaration => declaration.text.len(), + } +} diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index 653f810d439395a8825c99f4b007e05d881540ab..910835534af80ba97b99b8fc560c27bf13c4acda 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -68,7 +68,7 @@ impl Declaration { pub fn item_range(&self) -> Range { match self { - Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(), + Declaration::File { declaration, .. } => declaration.item_range.clone(), Declaration::Buffer { declaration, .. } => declaration.item_range.clone(), } } @@ -92,7 +92,7 @@ impl Declaration { pub fn signature_text(&self) -> (Cow<'_, str>, bool) { match self { Declaration::File { declaration, .. } => ( - declaration.text[declaration.signature_range_in_text.clone()].into(), + declaration.text[self.signature_range_in_item_text()].into(), declaration.signature_is_truncated, ), Declaration::Buffer { @@ -105,15 +105,19 @@ impl Declaration { } } - pub fn signature_range_in_item_text(&self) -> Range { + pub fn signature_range(&self) -> Range { match self { - Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(), - Declaration::Buffer { declaration, .. } => { - declaration.signature_range.start - declaration.item_range.start - ..declaration.signature_range.end - declaration.item_range.start - } + Declaration::File { declaration, .. } => declaration.signature_range.clone(), + Declaration::Buffer { declaration, .. 
} => declaration.signature_range.clone(), } } + + pub fn signature_range_in_item_text(&self) -> Range { + let signature_range = self.signature_range(); + let item_range = self.item_range(); + signature_range.start.saturating_sub(item_range.start) + ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len()) + } } fn expand_range_to_line_boundaries_and_truncate( @@ -141,13 +145,13 @@ pub struct FileDeclaration { pub parent: Option, pub identifier: Identifier, /// offset range of the declaration in the file, expanded to line boundaries and truncated - pub item_range_in_file: Range, - /// text of `item_range_in_file` + pub item_range: Range, + /// text of `item_range` pub text: Arc, /// whether `text` was truncated pub text_is_truncated: bool, - /// offset range of the signature within `text` - pub signature_range_in_text: Range, + /// offset range of the signature in the file, expanded to line boundaries and truncated + pub signature_range: Range, /// whether `signature` was truncated pub signature_is_truncated: bool, } @@ -160,31 +164,33 @@ impl FileDeclaration { rope, ); - // TODO: consider logging if unexpected - let signature_start = declaration - .signature_range - .start - .saturating_sub(item_range_in_file.start); - let mut signature_end = declaration - .signature_range - .end - .saturating_sub(item_range_in_file.start); - let signature_is_truncated = signature_end > item_range_in_file.len(); - if signature_is_truncated { - signature_end = item_range_in_file.len(); + let (mut signature_range_in_file, mut signature_is_truncated) = + expand_range_to_line_boundaries_and_truncate( + &declaration.signature_range, + ITEM_TEXT_TRUNCATION_LENGTH, + rope, + ); + + if signature_range_in_file.start < item_range_in_file.start { + signature_range_in_file.start = item_range_in_file.start; + signature_is_truncated = true; + } + if signature_range_in_file.end > item_range_in_file.end { + signature_range_in_file.end = item_range_in_file.end; + signature_is_truncated = true; } FileDeclaration { parent: None, identifier: declaration.identifier, - signature_range_in_text: signature_start..signature_end, + signature_range: signature_range_in_file, signature_is_truncated, text: rope .chunks_in_range(item_range_in_file.clone()) .collect::() .into(), text_is_truncated, - item_range_in_file, + item_range: item_range_in_file, } } } diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 1638723857d225d05efce512bb9025fa89fb38f3..fee7498a696c0608704dab6e8ab9f012c95660b5 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -40,10 +40,9 @@ impl ScoredSnippet { } pub fn size(&self, style: SnippetStyle) -> usize { - // TODO: how to handle truncation? match &self.declaration { Declaration::File { declaration, .. } => match style { - SnippetStyle::Signature => declaration.signature_range_in_text.len(), + SnippetStyle::Signature => declaration.signature_range.len(), SnippetStyle::Declaration => declaration.text.len(), }, Declaration::Buffer { declaration, .. 
} => match style { @@ -276,6 +275,8 @@ pub struct Scores { impl Scores { fn score(components: &ScoreComponents) -> Scores { + // TODO: handle truncation + // Score related to how likely this is the correct declaration, range 0 to 1 let accuracy_score = if components.is_same_file { // TODO: use declaration_line_distance_rank diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index d234e975d504c145d7bc2fc0680569c388ba0d1c..1b5e4268ccec74b9eea52c1001c7854dd746c5cf 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -578,11 +578,11 @@ mod tests { let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, main.clone()); - assert_eq!(decl.item_range_in_file, 32..280); + assert_eq!(decl.item_range, 32..280); let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx); assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range_in_file, 0..98); + assert_eq!(decl.item_range, 0..98); }); } diff --git a/crates/edit_prediction_context/src/wip_requests.rs b/crates/edit_prediction_context/src/wip_requests.rs deleted file mode 100644 index 9189587929725c8e1e4369fe5bd24cc641d6afab..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/wip_requests.rs +++ /dev/null @@ -1,35 +0,0 @@ -// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from -// `zeta_context.rs` in cloud. -// -// * Run excerpt selection at several different sizes, send the largest size with offsets within for -// the smaller sizes. -// -// * Longer event history. -// -// * Many more snippets than could fit in model context - allows ranking experimentation. - -pub struct Zeta2Request { - pub event_history: Vec, - pub excerpt: String, - pub excerpt_subsets: Vec, - /// Within `excerpt` - pub cursor_position: usize, - pub signatures: Vec, - pub retrieved_declarations: Vec, -} - -pub struct Zeta2ExcerptSubset { - /// Within `excerpt` text. - pub excerpt_range: Range, - /// Within `signatures`. - pub parent_signatures: Vec, -} - -pub struct ReferencedDeclaration { - pub text: Arc, - /// Range within `text` - pub signature_range: Range, - /// Indices within `signatures`. - pub parent_signatures: Vec, - // A bunch of score metrics -} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 791273a9242dd7aa50588fcfe90e9258ff3724ea..240d44fae44d9a430e1ed64816e11428a5bdb3d0 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -48,7 +48,7 @@ pub struct Zeta { llm_token: LlmApiToken, _llm_token_subscription: Subscription, projects: HashMap, - excerpt_options: EditPredictionExcerptOptions, + pub excerpt_options: EditPredictionExcerptOptions, update_required: bool, } @@ -87,7 +87,7 @@ impl Zeta { }) } - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); Self { @@ -478,6 +478,66 @@ impl Zeta { } } } + + // TODO: Dedupe with similar code in request_prediction? 
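For orientation, the new `cloud_zeta2_prompt` entry points compose like this (a minimal sketch; the zeta_cli changes later in this patch do the equivalent, and 8192 mirrors the CLI's default `--prompt-max-bytes`):

use cloud_llm_client::predict_edits_v3::PredictEditsRequest;
use cloud_zeta2_prompt::{PlanOptions, PlannedPrompt};

fn plan_prompt(request: &PredictEditsRequest) -> anyhow::Result<String> {
    // Greedily select snippets under the byte budget, then render the prompt text.
    let planned = PlannedPrompt::populate(request, &PlanOptions { max_bytes: 8192 })?;
    Ok(planned.to_prompt_string())
}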
+ pub fn cloud_request_for_zeta_cli( + &mut self, + project: &Entity, + buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task> { + let project_state = self.projects.get(&project.entity_id()); + + let index_state = project_state.map(|state| { + state + .syntax_index + .read_with(cx, |index, _cx| index.state().clone()) + }); + let excerpt_options = self.excerpt_options.clone(); + let snapshot = buffer.read(cx).snapshot(); + let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + let worktree_snapshots = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>(); + + cx.background_spawn(async move { + let index_state = if let Some(index_state) = index_state { + Some(index_state.lock_owned().await) + } else { + None + }; + + let cursor_point = position.to_point(&snapshot); + + let debug_info = true; + EditPredictionContext::gather_context( + cursor_point, + &snapshot, + &excerpt_options, + index_state.as_deref(), + ) + .context("Failed to select excerpt") + .map(|context| { + make_cloud_request( + excerpt_path.clone(), + context, + // TODO pass everything + Vec::new(), + false, + Vec::new(), + None, + debug_info, + &worktree_snapshots, + index_state.as_deref(), + ) + }) + }) + } } #[derive(Error, Debug)] @@ -840,13 +900,13 @@ fn make_cloud_request( for snippet in context.snippets { let project_entry_id = snippet.declaration.project_entry_id(); - // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot. - // Note that currently full_path is currently being used for excerpt_path. let Some(path) = worktrees.iter().find_map(|worktree| { - let abs_path = worktree.abs_path(); - worktree - .entry_for_id(project_entry_id) - .map(|e| abs_path.join(&e.path)) + worktree.entry_for_id(project_entry_id).map(|entry| { + let mut full_path = PathBuf::new(); + full_path.push(worktree.root_name()); + full_path.push(&entry.path); + full_path + }) }) else { continue; }; @@ -929,6 +989,7 @@ fn add_signature( text: text.into(), text_is_truncated, parent_index, + range: parent_declaration.signature_range(), }); declaration_to_signature_index.insert(declaration_id, signature_index); Some(signature_index) diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index e77351c219bac4425136e2a3f1752d73e76adbbf..38b85d7c3ac583b25f72240bfde6109a04e30c10 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -16,7 +16,9 @@ path = "src/main.rs" anyhow.workspace = true clap.workspace = true client.workspace = true +cloud_zeta2_prompt.workspace= true debug_adapter_extension.workspace = true +edit_prediction_context.workspace = true extension.workspace = true fs.workspace = true futures.workspace = true @@ -37,9 +39,10 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true shellexpand.workspace = true +smol.workspace = true terminal_view.workspace = true util.workspace = true watch.workspace = true workspace-hack.workspace = true zeta.workspace = true -smol.workspace = true +zeta2.workspace = true diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index e7cec26b19358056cee4c8e253c54c0b2c794b33..d2cde6be589a8138ef1e88e872a1e18294f4cb30 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -2,6 +2,7 @@ mod headless; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand}; +use 
edit_prediction_context::EditPredictionExcerptOptions; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, Application, AsyncApp}; @@ -18,7 +19,7 @@ use std::process::exit; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; +use zeta::{PerformPredictEditsParams, Zeta}; use crate::headless::ZetaCliAppState; @@ -32,6 +33,12 @@ struct ZetaCliArgs { #[derive(Subcommand, Debug)] enum Commands { Context(ContextArgs), + Zeta2Context { + #[clap(flatten)] + zeta2_args: Zeta2Args, + #[clap(flatten)] + context_args: ContextArgs, + }, Predict { #[arg(long)] predict_edits_body: Option, @@ -53,6 +60,18 @@ struct ContextArgs { events: Option, } +#[derive(Debug, Args)] +struct Zeta2Args { + #[arg(long, default_value_t = 8192)] + prompt_max_bytes: usize, + #[arg(long, default_value_t = 2048)] + excerpt_max_bytes: usize, + #[arg(long, default_value_t = 1024)] + excerpt_min_bytes: usize, + #[arg(long, default_value_t = 0.66)] + target_before_cursor_over_total_bytes: f32, +} + #[derive(Debug, Clone)] enum FileOrStdin { File(PathBuf), @@ -112,11 +131,17 @@ impl FromStr for CursorPosition { } } +enum GetContextOutput { + Zeta1(zeta::GatherContextOutput), + Zeta2(String), +} + async fn get_context( + zeta2_args: Option, args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, -) -> Result { +) -> Result { let ContextArgs { worktree: worktree_path, cursor, @@ -152,9 +177,7 @@ async fn get_context( open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?; (Some(lsp_open_handle), buffer) } else { - let abs_path = worktree_path.join(&cursor.path); - let content = smol::fs::read_to_string(&abs_path).await?; - let buffer = cx.new(|cx| Buffer::local(content, cx))?; + let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?; (None, buffer) }; @@ -189,33 +212,83 @@ async fn get_context( Some(events) => events.read_to_string().await?, None => String::new(), }; - let prompt_for_events = move || (events, 0); - cx.update(|cx| { - gather_context( - full_path_str, - &snapshot, - clipped_cursor, - prompt_for_events, - cx, - ) - })? - .await + + if let Some(zeta2_args) = zeta2_args { + Ok(GetContextOutput::Zeta2( + cx.update(|cx| { + let zeta = cx.new(|cx| { + zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) + }); + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&buffer, &project, cx); + zeta.excerpt_options = EditPredictionExcerptOptions { + max_bytes: zeta2_args.excerpt_max_bytes, + min_bytes: zeta2_args.excerpt_min_bytes, + target_before_cursor_over_total_bytes: zeta2_args + .target_before_cursor_over_total_bytes, + } + }); + // TODO: Actually wait for indexing. + let timer = cx.background_executor().timer(Duration::from_secs(5)); + cx.spawn(async move |cx| { + timer.await; + let request = zeta + .update(cx, |zeta, cx| { + let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); + zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) + })? + .await?; + let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate( + &request, + &cloud_zeta2_prompt::PlanOptions { + max_bytes: zeta2_args.prompt_max_bytes, + }, + )?; + anyhow::Ok(planned_prompt.to_prompt_string()) + }) + })? 
+ .await?, + )) + } else { + let prompt_for_events = move || (events, 0); + Ok(GetContextOutput::Zeta1( + cx.update(|cx| { + zeta::gather_context( + full_path_str, + &snapshot, + clipped_cursor, + prompt_for_events, + cx, + ) + })? + .await?, + )) + } } -pub async fn open_buffer_with_language_server( +pub async fn open_buffer( project: &Entity, worktree: &Entity, path: &Path, cx: &mut AsyncApp, -) -> Result<(Entity>, Entity)> { +) -> Result> { let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { worktree_id: worktree.id(), path: path.to_path_buf().into(), })?; - let buffer = project + project .update(cx, |project, cx| project.open_buffer(project_path, cx))? - .await?; + .await +} + +pub async fn open_buffer_with_language_server( + project: &Entity, + worktree: &Entity, + path: &Path, + cx: &mut AsyncApp, +) -> Result<(Entity>, Entity)> { + let buffer = open_buffer(project, worktree, path, cx).await?; let lsp_open_handle = project.update(cx, |project, cx| { project.register_buffer_with_language_servers(&buffer, cx) @@ -319,11 +392,26 @@ fn main() { app.run(move |cx| { let app_state = Arc::new(headless::init(cx)); + let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. }); cx.spawn(async move |cx| { let result = match args.command { - Commands::Context(context_args) => get_context(context_args, &app_state, cx) - .await - .map(|output| serde_json::to_string_pretty(&output.body).unwrap()), + Commands::Zeta2Context { + zeta2_args, + context_args, + } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await { + Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(), + Ok(GetContextOutput::Zeta2(output)) => Ok(output), + Err(err) => Err(err), + }, + Commands::Context(context_args) => { + match get_context(None, context_args, &app_state, cx).await { + Ok(GetContextOutput::Zeta1(output)) => { + Ok(serde_json::to_string_pretty(&output.body).unwrap()) + } + Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(), + Err(err) => Err(err), + } + } Commands::Predict { predict_edits_body, context_args, @@ -338,7 +426,10 @@ fn main() { if let Some(predict_edits_body) = predict_edits_body { serde_json::from_str(&predict_edits_body.read_to_string().await?)? } else if let Some(context_args) = context_args { - get_context(context_args, &app_state, cx).await?.body + match get_context(None, context_args, &app_state, cx).await? { + GetContextOutput::Zeta1(output) => output.body, + GetContextOutput::Zeta2 { .. } => unreachable!(), + } } else { return Err(anyhow!( "Expected either --predict-edits-body-file \ @@ -363,6 +454,10 @@ fn main() { match result { Ok(output) => { println!("{}", output); + // TODO: Remove this once the 5 second delay is properly replaced. + if is_zeta2_context_command { + eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds."); + } let _ = cx.update(|cx| cx.quit()); } Err(e) => {