diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 6a6090afab569d77ccfcbcb25d48bab0158ea335..60bbd8c8d6e55019f6b91df94a103eb83f3a100d 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,7 +1,7 @@ use chrono::Duration; use serde::{Deserialize, Serialize}; use std::{ - ops::Range, + ops::{Add, Range, Sub}, path::{Path, PathBuf}, sync::Arc, }; @@ -18,8 +18,8 @@ pub struct PredictEditsRequest { pub excerpt_path: Arc, /// Within file pub excerpt_range: Range, - /// Within `excerpt` - pub cursor_offset: usize, + pub excerpt_line_range: Range, + pub cursor_point: Point, /// Within `signatures` pub excerpt_parent: Option, pub signatures: Vec, @@ -47,12 +47,13 @@ pub struct PredictEditsRequest { pub enum PromptFormat { MarkedExcerpt, LabeledSections, + NumberedLines, /// Prompt format intended for use via zeta_cli OnlySnippets, } impl PromptFormat { - pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections; + pub const DEFAULT: PromptFormat = PromptFormat::NumberedLines; } impl Default for PromptFormat { @@ -73,6 +74,7 @@ impl std::fmt::Display for PromptFormat { PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"), PromptFormat::LabeledSections => write!(f, "Labeled Sections"), PromptFormat::OnlySnippets => write!(f, "Only Snippets"), + PromptFormat::NumberedLines => write!(f, "Numbered Lines"), } } } @@ -97,7 +99,7 @@ pub struct Signature { pub parent_index: Option, /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The /// file is implicitly the file that contains the descendant declaration or excerpt. - pub range: Range, + pub range: Range, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -106,7 +108,7 @@ pub struct ReferencedDeclaration { pub text: String, pub text_is_truncated: bool, /// Range of `text` within file, possibly truncated according to `text_is_truncated` - pub range: Range, + pub range: Range, /// Range within `text` pub signature_range: Range, /// Index within `signatures`. @@ -169,10 +171,36 @@ pub struct DebugInfo { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Edit { pub path: Arc, - pub range: Range, + pub range: Range, pub content: String, } fn is_default(value: &T) -> bool { *value == T::default() } + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] +pub struct Point { + pub line: Line, + pub column: u32, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] +#[serde(transparent)] +pub struct Line(pub u32); + +impl Add for Line { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0) + } +} + +impl Sub for Line { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0) + } +} diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index df70119b7fc91cc570e605fd5cebb9164d54f215..d68c0defef050985160688de0d541671866a91ac 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,7 +1,9 @@ //! Zeta2 prompt planning and generation code shared with cloud. use anyhow::{Context as _, Result, anyhow}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration}; +use cloud_llm_client::predict_edits_v3::{ + self, Event, Line, Point, PromptFormat, ReferencedDeclaration, +}; use indoc::indoc; use ordered_float::OrderedFloat; use rustc_hash::{FxHashMap, FxHashSet}; @@ -43,6 +45,42 @@ const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#" } "#}; +const NUMBERED_LINES_SYSTEM_PROMPT: &str = indoc! {r#" + # Instructions + + You are a code completion assistant helping a programmer finish their work. Your task is to: + + 1. Analyze the edit history to understand what the programmer is trying to achieve + 2. Identify any incomplete refactoring or changes that need to be finished + 3. Make the remaining edits that a human programmer would logically make next + 4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere. + + Focus on: + - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs) + - Completing any partially-applied changes across the codebase + - Ensuring consistency with the programming style and patterns already established + - Making edits that maintain or improve code quality + - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances + - Don't write a lot of code if you're not sure what to do + + Rules: + - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. + - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. + - Write the edits in the unified diff format as shown in the example. + + # Example output: + + ``` + --- a/distill-claude/tmp-outs/edits_history.txt + +++ b/distill-claude/tmp-outs/edits_history.txt + @@ -1,3 +1,3 @@ + - + - + -import sys + +import json + ``` +"#}; + pub struct PlannedPrompt<'a> { request: &'a predict_edits_v3::PredictEditsRequest, /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in @@ -55,6 +93,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str { match format { PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT, PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT, + PromptFormat::NumberedLines => NUMBERED_LINES_SYSTEM_PROMPT, // only intended for use via zeta_cli PromptFormat::OnlySnippets => "", } @@ -63,7 +102,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str { #[derive(Clone, Debug)] pub struct PlannedSnippet<'a> { path: Arc, - range: Range, + range: Range, text: &'a str, // TODO: Indicate this in the output #[allow(dead_code)] @@ -79,7 +118,7 @@ pub enum DeclarationStyle { #[derive(Clone, Debug, Serialize)] pub struct SectionLabels { pub excerpt_index: usize, - pub section_ranges: Vec<(Arc, Range)>, + pub section_ranges: Vec<(Arc, Range)>, } impl<'a> PlannedPrompt<'a> { @@ -196,10 +235,24 @@ impl<'a> PlannedPrompt<'a> { declaration.text.len() )); }; + let signature_start_line = declaration.range.start + + Line( + declaration.text[..declaration.signature_range.start] + .lines() + .count() as u32, + ); + let signature_end_line = signature_start_line + + Line( + declaration.text + [declaration.signature_range.start..declaration.signature_range.end] + .lines() + .count() as u32, + ); + let range = signature_start_line..signature_end_line; + PlannedSnippet { path: declaration.path.clone(), - range: (declaration.signature_range.start + declaration.range.start) - ..(declaration.signature_range.end + declaration.range.start), + range, text, text_is_truncated: declaration.text_is_truncated, } @@ -318,7 +371,7 @@ impl<'a> PlannedPrompt<'a> { } let excerpt_snippet = PlannedSnippet { path: self.request.excerpt_path.clone(), - range: self.request.excerpt_range.clone(), + range: self.request.excerpt_line_range.clone(), text: &self.request.excerpt, text_is_truncated: false, }; @@ -328,32 +381,33 @@ impl<'a> PlannedPrompt<'a> { let mut excerpt_file_insertions = match self.request.prompt_format { PromptFormat::MarkedExcerpt => vec![ ( - self.request.excerpt_range.start, + Point { + line: self.request.excerpt_line_range.start, + column: 0, + }, EDITABLE_REGION_START_MARKER_WITH_NEWLINE, ), + (self.request.cursor_point, CURSOR_MARKER), ( - self.request.excerpt_range.start + self.request.cursor_offset, - CURSOR_MARKER, - ), - ( - self.request - .excerpt_range - .end - .saturating_sub(0) - .max(self.request.excerpt_range.start), + Point { + line: self.request.excerpt_line_range.end, + column: 0, + }, EDITABLE_REGION_END_MARKER_WITH_NEWLINE, ), ], - PromptFormat::LabeledSections => vec![( - self.request.excerpt_range.start + self.request.cursor_offset, - CURSOR_MARKER, - )], + PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)], + PromptFormat::NumberedLines => vec![(self.request.cursor_point, CURSOR_MARKER)], PromptFormat::OnlySnippets => vec![], }; let mut prompt = String::new(); prompt.push_str("## User Edits\n\n"); - Self::push_events(&mut prompt, &self.request.events); + if self.request.events.is_empty() { + prompt.push_str("No edits yet.\n"); + } else { + Self::push_events(&mut prompt, &self.request.events); + } prompt.push_str("\n## Code\n\n"); let section_labels = @@ -391,13 +445,17 @@ impl<'a> PlannedPrompt<'a> { if *predicted { writeln!( output, - "User accepted prediction {:?}:\n```diff\n{}\n```\n", + "User accepted prediction {:?}:\n`````diff\n{}\n`````\n", path, diff ) .unwrap(); } else { - writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff) - .unwrap(); + writeln!( + output, + "User edited {:?}:\n`````diff\n{}\n`````\n", + path, diff + ) + .unwrap(); } } } @@ -407,7 +465,7 @@ impl<'a> PlannedPrompt<'a> { fn push_file_snippets( &self, output: &mut String, - excerpt_file_insertions: &mut Vec<(usize, &'static str)>, + excerpt_file_insertions: &mut Vec<(Point, &'static str)>, file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>, ) -> Result { let mut section_ranges = Vec::new(); @@ -417,15 +475,13 @@ impl<'a> PlannedPrompt<'a> { snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end))); // TODO: What if the snippets get expanded too large to be editable? - let mut current_snippet: Option<(&PlannedSnippet, Range)> = None; - let mut disjoint_snippets: Vec<(&PlannedSnippet, Range)> = Vec::new(); + let mut current_snippet: Option<(&PlannedSnippet, Range)> = None; + let mut disjoint_snippets: Vec<(&PlannedSnippet, Range)> = Vec::new(); for snippet in snippets { if let Some((_, current_snippet_range)) = current_snippet.as_mut() - && snippet.range.start < current_snippet_range.end + && snippet.range.start <= current_snippet_range.end { - if snippet.range.end > current_snippet_range.end { - current_snippet_range.end = snippet.range.end; - } + current_snippet_range.end = current_snippet_range.end.max(snippet.range.end); continue; } if let Some(current_snippet) = current_snippet.take() { @@ -437,21 +493,24 @@ impl<'a> PlannedPrompt<'a> { disjoint_snippets.push(current_snippet); } - writeln!(output, "```{}", file_path.display()).ok(); + // TODO: remove filename=? + writeln!(output, "`````filename={}", file_path.display()).ok(); let mut skipped_last_snippet = false; for (snippet, range) in disjoint_snippets { let section_index = section_ranges.len(); match self.request.prompt_format { - PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => { - if range.start > 0 && !skipped_last_snippet { + PromptFormat::MarkedExcerpt + | PromptFormat::OnlySnippets + | PromptFormat::NumberedLines => { + if range.start.0 > 0 && !skipped_last_snippet { output.push_str("…\n"); } } PromptFormat::LabeledSections => { if is_excerpt_file - && range.start <= self.request.excerpt_range.start - && range.end >= self.request.excerpt_range.end + && range.start <= self.request.excerpt_line_range.start + && range.end >= self.request.excerpt_line_range.end { writeln!(output, "<|current_section|>").ok(); } else { @@ -460,46 +519,83 @@ impl<'a> PlannedPrompt<'a> { } } + let push_full_snippet = |output: &mut String| { + if self.request.prompt_format == PromptFormat::NumberedLines { + for (i, line) in snippet.text.lines().enumerate() { + writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?; + } + } else { + output.push_str(&snippet.text); + } + anyhow::Ok(()) + }; + if is_excerpt_file { if self.request.prompt_format == PromptFormat::OnlySnippets { - if range.start >= self.request.excerpt_range.start - && range.end <= self.request.excerpt_range.end + if range.start >= self.request.excerpt_line_range.start + && range.end <= self.request.excerpt_line_range.end { skipped_last_snippet = true; } else { skipped_last_snippet = false; output.push_str(snippet.text); } - } else { - let mut last_offset = range.start; - let mut i = 0; - while i < excerpt_file_insertions.len() { - let (offset, insertion) = &excerpt_file_insertions[i]; - let found = *offset >= range.start && *offset <= range.end; + } else if !excerpt_file_insertions.is_empty() { + let lines = snippet.text.lines().collect::>(); + let push_line = |output: &mut String, line_ix: usize| { + if self.request.prompt_format == PromptFormat::NumberedLines { + write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?; + } + anyhow::Ok(writeln!(output, "{}", lines[line_ix])?) + }; + let mut last_line_ix = 0; + let mut insertion_ix = 0; + while insertion_ix < excerpt_file_insertions.len() { + let (point, insertion) = &excerpt_file_insertions[insertion_ix]; + let found = point.line >= range.start && point.line <= range.end; if found { excerpt_index = Some(section_index); - output.push_str( - &snippet.text[last_offset - range.start..offset - range.start], - ); - output.push_str(insertion); - last_offset = *offset; - excerpt_file_insertions.remove(i); + let insertion_line_ix = (point.line.0 - range.start.0) as usize; + for line_ix in last_line_ix..insertion_line_ix { + push_line(output, line_ix)?; + } + if let Some(next_line) = lines.get(insertion_line_ix) { + if self.request.prompt_format == PromptFormat::NumberedLines { + write!( + output, + "{}|", + insertion_line_ix as u32 + range.start.0 + 1 + )? + } + output.push_str(&next_line[..point.column as usize]); + output.push_str(insertion); + writeln!(output, "{}", &next_line[point.column as usize..])?; + } else { + writeln!(output, "{}", insertion)?; + } + last_line_ix = insertion_line_ix + 1; + excerpt_file_insertions.remove(insertion_ix); continue; } - i += 1; + insertion_ix += 1; } skipped_last_snippet = false; - output.push_str(&snippet.text[last_offset - range.start..]); + for line_ix in last_line_ix..lines.len() { + push_line(output, line_ix)?; + } + } else { + skipped_last_snippet = false; + push_full_snippet(output)?; } } else { skipped_last_snippet = false; - output.push_str(snippet.text); + push_full_snippet(output)?; } section_ranges.push((snippet.path.clone(), range)); } - output.push_str("```\n\n"); + output.push_str("`````\n\n"); } Ok(SectionLabels { diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index b57054cb537655184d4a52b511213dcfa570cd87..cc32640425ecc563b1f24a6c695be1c13199cd73 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -1,3 +1,4 @@ +use cloud_llm_client::predict_edits_v3::{self, Line}; use language::{Language, LanguageId}; use project::ProjectEntryId; use std::ops::Range; @@ -91,6 +92,18 @@ impl Declaration { } } + pub fn item_line_range(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.item_line_range.clone(), + Declaration::Buffer { + declaration, rope, .. + } => { + Line(rope.offset_to_point(declaration.item_range.start).row) + ..Line(rope.offset_to_point(declaration.item_range.end).row) + } + } + } + pub fn item_text(&self) -> (Cow<'_, str>, bool) { match self { Declaration::File { declaration, .. } => ( @@ -130,6 +143,18 @@ impl Declaration { } } + pub fn signature_line_range(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.signature_line_range.clone(), + Declaration::Buffer { + declaration, rope, .. + } => { + Line(rope.offset_to_point(declaration.signature_range.start).row) + ..Line(rope.offset_to_point(declaration.signature_range.end).row) + } + } + } + pub fn signature_range_in_item_text(&self) -> Range { let signature_range = self.signature_range(); let item_range = self.item_range(); @@ -142,7 +167,7 @@ fn expand_range_to_line_boundaries_and_truncate( range: &Range, limit: usize, rope: &Rope, -) -> (Range, bool) { +) -> (Range, Range, bool) { let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end); point_range.start.column = 0; point_range.end.row += 1; @@ -155,7 +180,10 @@ fn expand_range_to_line_boundaries_and_truncate( item_range.end = item_range.start + limit; } item_range.end = rope.clip_offset(item_range.end, Bias::Left); - (item_range, is_truncated) + + let line_range = + predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row); + (item_range, line_range, is_truncated) } #[derive(Debug, Clone)] @@ -164,25 +192,30 @@ pub struct FileDeclaration { pub identifier: Identifier, /// offset range of the declaration in the file, expanded to line boundaries and truncated pub item_range: Range, + /// line range of the declaration in the file, potentially truncated + pub item_line_range: Range, /// text of `item_range` pub text: Arc, /// whether `text` was truncated pub text_is_truncated: bool, /// offset range of the signature in the file, expanded to line boundaries and truncated pub signature_range: Range, + /// line range of the signature in the file, truncated + pub signature_line_range: Range, /// whether `signature` was truncated pub signature_is_truncated: bool, } impl FileDeclaration { pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration { - let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); + let (item_range_in_file, item_line_range_in_file, text_is_truncated) = + expand_range_to_line_boundaries_and_truncate( + &declaration.item_range, + ITEM_TEXT_TRUNCATION_LENGTH, + rope, + ); - let (mut signature_range_in_file, mut signature_is_truncated) = + let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) = expand_range_to_line_boundaries_and_truncate( &declaration.signature_range, ITEM_TEXT_TRUNCATION_LENGTH, @@ -202,6 +235,7 @@ impl FileDeclaration { parent: None, identifier: declaration.identifier, signature_range: signature_range_in_file, + signature_line_range, signature_is_truncated, text: rope .chunks_in_range(item_range_in_file.clone()) @@ -209,6 +243,7 @@ impl FileDeclaration { .into(), text_is_truncated, item_range: item_range_in_file, + item_line_range: item_line_range_in_file, } } } @@ -225,12 +260,13 @@ pub struct BufferDeclaration { impl BufferDeclaration { pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self { - let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - let (signature_range, signature_range_is_truncated) = + let (item_range, _item_line_range, item_range_is_truncated) = + expand_range_to_line_boundaries_and_truncate( + &declaration.item_range, + ITEM_TEXT_TRUNCATION_LENGTH, + rope, + ); + let (signature_range, _signature_line_range, signature_range_is_truncated) = expand_range_to_line_boundaries_and_truncate( &declaration.signature_range, ITEM_TEXT_TRUNCATION_LENGTH, diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 19cafe0412bb0db67ef906d1ff119d7c23234f78..85b0c36d7342b8c83a6a6befb38a3f0c9753b093 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -9,6 +9,7 @@ pub mod text_similarity; use std::{path::Path, sync::Arc}; +use cloud_llm_client::predict_edits_v3; use collections::HashMap; use gpui::{App, AppContext as _, Entity, Task}; use language::BufferSnapshot; @@ -21,6 +22,8 @@ pub use imports::*; pub use reference::*; pub use syntax_index::*; +pub use predict_edits_v3::Line; + #[derive(Clone, Debug, PartialEq)] pub struct EditPredictionContextOptions { pub use_imports: bool, @@ -32,7 +35,7 @@ pub struct EditPredictionContextOptions { pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, pub excerpt_text: EditPredictionExcerptText, - pub cursor_offset_in_excerpt: usize, + pub cursor_point: Point, pub declarations: Vec, } @@ -124,8 +127,6 @@ impl EditPredictionContext { ); let cursor_offset_in_file = cursor_point.to_offset(buffer); - // TODO fix this to not need saturating_sub - let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start); let declarations = if let Some(index_state) = index_state { let references = get_references(&excerpt, &excerpt_text, buffer); @@ -148,7 +149,7 @@ impl EditPredictionContext { Some(Self { excerpt, excerpt_text, - cursor_offset_in_excerpt, + cursor_point, declarations, }) } diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index 58549d579dca2b589fb8da01e4963782845933e9..7a4bb73edfa131b620a930d7f0e1c0da77e0afe6 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -4,7 +4,7 @@ use text::{Point, ToOffset as _, ToPoint as _}; use tree_sitter::{Node, TreeCursor}; use util::RangeExt; -use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState}; +use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState}; // TODO: // @@ -35,6 +35,7 @@ pub struct EditPredictionExcerptOptions { #[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, + pub line_range: Range, pub parent_declarations: Vec<(DeclarationId, Range)>, pub size: usize, } @@ -86,12 +87,19 @@ impl EditPredictionExcerpt { buffer.len(), options.max_bytes ); - return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new())); + let offset_range = 0..buffer.len(); + let line_range = Line(0)..Line(buffer.max_point().row); + return Some(EditPredictionExcerpt::new( + offset_range, + line_range, + Vec::new(), + )); } let query_offset = query_point.to_offset(buffer); - let query_range = Point::new(query_point.row, 0).to_offset(buffer) - ..Point::new(query_point.row + 1, 0).to_offset(buffer); + let query_line_range = query_point.row..query_point.row + 1; + let query_range = Point::new(query_line_range.start, 0).to_offset(buffer) + ..Point::new(query_line_range.end, 0).to_offset(buffer); if query_range.len() >= options.max_bytes { return None; } @@ -107,6 +115,7 @@ impl EditPredictionExcerpt { let excerpt_selector = ExcerptSelector { query_offset, query_range, + query_line_range: Line(query_line_range.start)..Line(query_line_range.end), parent_declarations: &parent_declarations, buffer, options, @@ -130,7 +139,11 @@ impl EditPredictionExcerpt { excerpt_selector.select_lines() } - fn new(range: Range, parent_declarations: Vec<(DeclarationId, Range)>) -> Self { + fn new( + range: Range, + line_range: Range, + parent_declarations: Vec<(DeclarationId, Range)>, + ) -> Self { let size = range.len() + parent_declarations .iter() @@ -140,10 +153,11 @@ impl EditPredictionExcerpt { range, parent_declarations, size, + line_range, } } - fn with_expanded_range(&self, new_range: Range) -> Self { + fn with_expanded_range(&self, new_range: Range, new_line_range: Range) -> Self { if !new_range.contains_inclusive(&self.range) { // this is an issue because parent_signature_ranges may be incorrect log::error!("bug: with_expanded_range called with disjoint range"); @@ -155,7 +169,7 @@ impl EditPredictionExcerpt { } parent_declarations.push((*declaration_id, range.clone())); } - Self::new(new_range, parent_declarations) + Self::new(new_range, new_line_range, parent_declarations) } fn parent_signatures_size(&self) -> usize { @@ -166,6 +180,7 @@ impl EditPredictionExcerpt { struct ExcerptSelector<'a> { query_offset: usize, query_range: Range, + query_line_range: Range, parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)], buffer: &'a BufferSnapshot, options: &'a EditPredictionExcerptOptions, @@ -178,10 +193,13 @@ impl<'a> ExcerptSelector<'a> { let mut cursor = selected_layer_root.walk(); loop { - let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer) - ..node_line_end(cursor.node()).to_offset(&self.buffer); + let line_start = node_line_start(cursor.node()); + let line_end = node_line_end(cursor.node()); + let line_range = Line(line_start.row)..Line(line_end.row); + let excerpt_range = + line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer); if excerpt_range.contains_inclusive(&self.query_range) { - let excerpt = self.make_excerpt(excerpt_range); + let excerpt = self.make_excerpt(excerpt_range, line_range); if excerpt.size <= self.options.max_bytes { return Some(self.expand_to_siblings(&mut cursor, excerpt)); } @@ -272,9 +290,13 @@ impl<'a> ExcerptSelector<'a> { let mut forward = None; while !forward_done { - let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer); + let new_end_point = node_line_end(forward_cursor.node()); + let new_end = new_end_point.to_offset(&self.buffer); if new_end > excerpt.range.end { - let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end); + let new_excerpt = excerpt.with_expanded_range( + excerpt.range.start..new_end, + excerpt.line_range.start..Line(new_end_point.row), + ); if new_excerpt.size <= self.options.max_bytes { forward = Some(new_excerpt); break; @@ -289,9 +311,13 @@ impl<'a> ExcerptSelector<'a> { let mut backward = None; while !backward_done { - let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer); + let new_start_point = node_line_start(backward_cursor.node()); + let new_start = new_start_point.to_offset(&self.buffer); if new_start < excerpt.range.start { - let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end); + let new_excerpt = excerpt.with_expanded_range( + new_start..excerpt.range.end, + Line(new_start_point.row)..excerpt.line_range.end, + ); if new_excerpt.size <= self.options.max_bytes { backward = Some(new_excerpt); break; @@ -339,7 +365,7 @@ impl<'a> ExcerptSelector<'a> { fn select_lines(&self) -> Option { // early return if line containing query_offset is already too large - let excerpt = self.make_excerpt(self.query_range.clone()); + let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone()); if excerpt.size > self.options.max_bytes { log::debug!( "excerpt for cursor line is {} bytes, which exceeds the window", @@ -353,24 +379,24 @@ impl<'a> ExcerptSelector<'a> { let before_bytes = (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize; - let start_point = { + let start_line = { let offset = self.query_offset.saturating_sub(before_bytes); let point = offset.to_point(self.buffer); - Point::new(point.row + 1, 0) + Line(point.row + 1) }; - let start_offset = start_point.to_offset(&self.buffer); - let end_point = { + let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer); + let end_line = { let offset = start_offset + bytes_remaining; let point = offset.to_point(self.buffer); - Point::new(point.row, 0) + Line(point.row) }; - let end_offset = end_point.to_offset(&self.buffer); + let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer); // this could be expanded further since recalculated `signature_size` may be smaller, but // skipping that for now for simplicity // // TODO: could also consider checking if lines immediately before / after fit. - let excerpt = self.make_excerpt(start_offset..end_offset); + let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line); if excerpt.size > self.options.max_bytes { log::error!( "bug: line-based excerpt selection has size {}, \ @@ -382,14 +408,14 @@ impl<'a> ExcerptSelector<'a> { return Some(excerpt); } - fn make_excerpt(&self, range: Range) -> EditPredictionExcerpt { + fn make_excerpt(&self, range: Range, line_range: Range) -> EditPredictionExcerpt { let parent_declarations = self .parent_declarations .iter() .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range)) .map(|(id, declaration)| (*id, declaration.signature_range.clone())) .collect(); - EditPredictionExcerpt::new(range, parent_declarations) + EditPredictionExcerpt::new(range, line_range, parent_declarations) } /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index d4832993b9ecd7c40f154f2ab696c66872073d5e..9611d48023d84a91e477a51ff863b9ca6f0566a8 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -33,7 +33,7 @@ pub struct EditPrediction { pub snapshot: BufferSnapshot, pub edit_preview: EditPreview, // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction. - _buffer: Entity, + pub buffer: Entity, } impl EditPrediction { @@ -108,7 +108,7 @@ impl EditPrediction { edits, snapshot, edit_preview, - _buffer: buffer, + buffer, }) } @@ -184,6 +184,10 @@ pub fn interpolate_edits( if edits.is_empty() { None } else { Some(edits) } } +pub fn line_range_to_point_range(range: Range) -> Range { + language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0) +} + fn edits_from_response( edits: &[predict_edits_v3::Edit], snapshot: &TextBufferSnapshot, @@ -191,12 +195,14 @@ fn edits_from_response( edits .iter() .flat_map(|edit| { - let old_text = snapshot.text_for_range(edit.range.clone()); + let point_range = line_range_to_point_range(edit.range.clone()); + let offset = point_range.to_offset(snapshot).start; + let old_text = snapshot.text_for_range(point_range); excerpt_edits_from_response( old_text.collect::>(), &edit.content, - edit.range.start, + offset, &snapshot, ) }) @@ -252,6 +258,7 @@ mod tests { use super::*; use cloud_llm_client::predict_edits_v3; + use edit_prediction_context::Line; use gpui::{App, Entity, TestAppContext, prelude::*}; use indoc::indoc; use language::{Buffer, ToOffset as _}; @@ -278,7 +285,7 @@ mod tests { // TODO cover more cases when multi-file is supported let big_edits = vec![predict_edits_v3::Edit { path: PathBuf::from("test.txt").into(), - range: 0..old.len(), + range: Line(0)..Line(old.lines().count() as u32), content: new.into(), }]; @@ -317,7 +324,7 @@ mod tests { edits, snapshot: cx.read(|cx| buffer.read(cx).snapshot()), path: Path::new("test.txt").into(), - _buffer: buffer.clone(), + buffer: buffer.clone(), edit_preview, }; diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index ab786fa80a6520f2b2ceb3cda177dab4b7120bc2..16caee7aefde8c38d7a466fc9e1197c7ad21b94a 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -17,8 +17,8 @@ use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, }; +use language::BufferSnapshot; use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; -use language::{BufferSnapshot, TextBufferSnapshot}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; @@ -106,30 +106,40 @@ struct ZetaProject { current_prediction: Option, } -#[derive(Clone)] +#[derive(Debug, Clone)] struct CurrentEditPrediction { pub requested_by_buffer_id: EntityId, pub prediction: EditPrediction, } impl CurrentEditPrediction { - fn should_replace_prediction( - &self, - old_prediction: &Self, - snapshot: &TextBufferSnapshot, - ) -> bool { - if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id { + fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { + let Some(new_edits) = self + .prediction + .interpolate(&self.prediction.buffer.read(cx)) + else { + return false; + }; + + if self.prediction.buffer != old_prediction.prediction.buffer { return true; } - let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { + let Some(old_edits) = old_prediction + .prediction + .interpolate(&old_prediction.prediction.buffer.read(cx)) + else { return true; }; - let Some(new_edits) = self.prediction.interpolate(snapshot) else { - return false; - }; - if old_edits.len() == 1 && new_edits.len() == 1 { + // This reduces the occurrence of UI thrash from replacing edits + // + // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. + if self.requested_by_buffer_id == self.prediction.buffer.entity_id() + && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id() + && old_edits.len() == 1 + && new_edits.len() == 1 + { let (old_range, old_text) = &old_edits[0]; let (new_range, new_text) = &new_edits[0]; new_range == old_range && new_text.starts_with(old_text) @@ -421,8 +431,7 @@ impl Zeta { .current_prediction .as_ref() .is_none_or(|old_prediction| { - new_prediction - .should_replace_prediction(&old_prediction, buffer.read(cx)) + new_prediction.should_replace_prediction(&old_prediction, cx) }) { project_state.current_prediction = Some(new_prediction); @@ -926,7 +935,7 @@ fn make_cloud_request( referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { path: path.as_std_path().into(), text: text.into(), - range: snippet.declaration.item_range(), + range: snippet.declaration.item_line_range(), text_is_truncated, signature_range: snippet.declaration.signature_range_in_item_text(), parent_index, @@ -954,8 +963,12 @@ fn make_cloud_request( predict_edits_v3::PredictEditsRequest { excerpt_path, excerpt: context.excerpt_text.body, + excerpt_line_range: context.excerpt.line_range, excerpt_range: context.excerpt.range, - cursor_offset: context.cursor_offset_in_excerpt, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(context.cursor_point.row), + column: context.cursor_point.column, + }, referenced_declarations, signatures, excerpt_parent, @@ -992,7 +1005,7 @@ fn add_signature( text: text.into(), text_is_truncated, parent_index, - range: parent_declaration.signature_range(), + range: parent_declaration.signature_line_range(), }); declaration_to_signature_index.insert(declaration_id, signature_index); Some(signature_index) @@ -1007,7 +1020,8 @@ mod tests { use client::UserStore; use clock::FakeSystemClock; - use cloud_llm_client::predict_edits_v3; + use cloud_llm_client::predict_edits_v3::{self, Point}; + use edit_prediction_context::Line; use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, @@ -1067,7 +1081,7 @@ mod tests { request_id: Uuid::new_v4(), edits: vec![predict_edits_v3::Edit { path: Path::new(path!("root/1.txt")).into(), - range: 0..snapshot1.len(), + range: Line(0)..Line(snapshot1.max_point().row + 1), content: "Hello!\nHow are you?\nBye".into(), }], debug_info: None, @@ -1083,7 +1097,6 @@ mod tests { }); // Prediction for another file - let prediction_task = zeta.update(cx, |zeta, cx| { zeta.refresh_prediction(&project, &buffer1, position, cx) }); @@ -1093,14 +1106,13 @@ mod tests { request_id: Uuid::new_v4(), edits: vec![predict_edits_v3::Edit { path: Path::new(path!("root/2.txt")).into(), - range: 0..snapshot1.len(), + range: Line(0)..Line(snapshot1.max_point().row + 1), content: "Hola!\nComo estas?\nAdios".into(), }], debug_info: None, }) .unwrap(); prediction_task.await.unwrap(); - zeta.read_with(cx, |zeta, cx| { let prediction = zeta .current_prediction_for_buffer(&buffer1, &project, cx) @@ -1159,14 +1171,20 @@ mod tests { request.excerpt_path.as_ref(), Path::new(path!("root/foo.md")) ); - assert_eq!(request.cursor_offset, 10); + assert_eq!( + request.cursor_point, + Point { + line: Line(1), + column: 3 + } + ); respond_tx .send(predict_edits_v3::PredictEditsResponse { request_id: Uuid::new_v4(), edits: vec![predict_edits_v3::Edit { path: Path::new(path!("root/foo.md")).into(), - range: 0..snapshot.len(), + range: Line(0)..Line(snapshot.max_point().row + 1), content: "Hello!\nHow are you?\nBye".into(), }], debug_info: None, @@ -1244,7 +1262,7 @@ mod tests { request_id: Uuid::new_v4(), edits: vec![predict_edits_v3::Edit { path: Path::new(path!("root/foo.md")).into(), - range: 0..snapshot.len(), + range: Line(0)..Line(snapshot.max_point().row + 1), content: "Hello!\nHow are you?\nBye".into(), }], debug_info: None, diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 7e490266decb59cf16fb9ce4acd8ecc8d7bc0f91..efd5dd2d0688571cf8cef9e77b7d89c6e8ad33a9 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -98,10 +98,11 @@ struct Zeta2Args { #[derive(clap::ValueEnum, Default, Debug, Clone)] enum PromptFormat { - #[default] MarkedExcerpt, LabeledSections, OnlySnippets, + #[default] + NumberedLines, } impl Into for PromptFormat { @@ -110,6 +111,7 @@ impl Into for PromptFormat { Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt, Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, + Self::NumberedLines => predict_edits_v3::PromptFormat::NumberedLines, } } }