zeta2: Numbered lines prompt format (#40218)

Agus Zubiaga , Michael Sloan , and Michael created

Adds a new `NumberedLines` format which is similar to `MarkedExcerpt`
but each line is prefixed with its line number.

Also fixes a bug where contagious snippets wouldn't get merged.

Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <mgsloan@gmail.com>
Co-authored-by: Michael <michael@zed.dev>

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs               |  42 
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs           | 204 +++-
crates/edit_prediction_context/src/declaration.rs             |  64 +
crates/edit_prediction_context/src/edit_prediction_context.rs |   9 
crates/edit_prediction_context/src/excerpt.rs                 |  74 +
crates/zeta2/src/prediction.rs                                |  19 
crates/zeta2/src/zeta2.rs                                     |  70 +
crates/zeta_cli/src/main.rs                                   |   4 
8 files changed, 350 insertions(+), 136 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -1,7 +1,7 @@
 use chrono::Duration;
 use serde::{Deserialize, Serialize};
 use std::{
-    ops::Range,
+    ops::{Add, Range, Sub},
     path::{Path, PathBuf},
     sync::Arc,
 };
@@ -18,8 +18,8 @@ pub struct PredictEditsRequest {
     pub excerpt_path: Arc<Path>,
     /// Within file
     pub excerpt_range: Range<usize>,
-    /// Within `excerpt`
-    pub cursor_offset: usize,
+    pub excerpt_line_range: Range<Line>,
+    pub cursor_point: Point,
     /// Within `signatures`
     pub excerpt_parent: Option<usize>,
     pub signatures: Vec<Signature>,
@@ -47,12 +47,13 @@ pub struct PredictEditsRequest {
 pub enum PromptFormat {
     MarkedExcerpt,
     LabeledSections,
+    NumberedLines,
     /// Prompt format intended for use via zeta_cli
     OnlySnippets,
 }
 
 impl PromptFormat {
-    pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
+    pub const DEFAULT: PromptFormat = PromptFormat::NumberedLines;
 }
 
 impl Default for PromptFormat {
@@ -73,6 +74,7 @@ impl std::fmt::Display for PromptFormat {
             PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
             PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
             PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
+            PromptFormat::NumberedLines => write!(f, "Numbered Lines"),
         }
     }
 }
@@ -97,7 +99,7 @@ pub struct Signature {
     pub parent_index: Option<usize>,
     /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
     /// file is implicitly the file that contains the descendant declaration or excerpt.
-    pub range: Range<usize>,
+    pub range: Range<Line>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -106,7 +108,7 @@ pub struct ReferencedDeclaration {
     pub text: String,
     pub text_is_truncated: bool,
     /// Range of `text` within file, possibly truncated according to `text_is_truncated`
-    pub range: Range<usize>,
+    pub range: Range<Line>,
     /// Range within `text`
     pub signature_range: Range<usize>,
     /// Index within `signatures`.
@@ -169,10 +171,36 @@ pub struct DebugInfo {
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Edit {
     pub path: Arc<Path>,
-    pub range: Range<usize>,
+    pub range: Range<Line>,
     pub content: String,
 }
 
 fn is_default<T: Default + PartialEq>(value: &T) -> bool {
     *value == T::default()
 }
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
+pub struct Point {
+    pub line: Line,
+    pub column: u32,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
+#[serde(transparent)]
+pub struct Line(pub u32);
+
+impl Add for Line {
+    type Output = Self;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        Self(self.0 + rhs.0)
+    }
+}
+
+impl Sub for Line {
+    type Output = Self;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        Self(self.0 - rhs.0)
+    }
+}

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -1,7 +1,9 @@
 //! Zeta2 prompt planning and generation code shared with cloud.
 
 use anyhow::{Context as _, Result, anyhow};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
+use cloud_llm_client::predict_edits_v3::{
+    self, Event, Line, Point, PromptFormat, ReferencedDeclaration,
+};
 use indoc::indoc;
 use ordered_float::OrderedFloat;
 use rustc_hash::{FxHashMap, FxHashSet};
@@ -43,6 +45,42 @@ const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
     }
 "#};
 
+const NUMBERED_LINES_SYSTEM_PROMPT: &str = indoc! {r#"
+    # Instructions
+
+    You are a code completion assistant helping a programmer finish their work. Your task is to:
+
+    1. Analyze the edit history to understand what the programmer is trying to achieve
+    2. Identify any incomplete refactoring or changes that need to be finished
+    3. Make the remaining edits that a human programmer would logically make next
+    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+
+    Focus on:
+    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
+    - Completing any partially-applied changes across the codebase
+    - Ensuring consistency with the programming style and patterns already established
+    - Making edits that maintain or improve code quality
+    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
+    - Don't write a lot of code if you're not sure what to do
+
+    Rules:
+    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+    - Write the edits in the unified diff format as shown in the example.
+
+    # Example output:
+
+    ```
+    --- a/distill-claude/tmp-outs/edits_history.txt
+    +++ b/distill-claude/tmp-outs/edits_history.txt
+    @@ -1,3 +1,3 @@
+    -
+    -
+    -import sys
+    +import json
+    ```
+"#};
+
 pub struct PlannedPrompt<'a> {
     request: &'a predict_edits_v3::PredictEditsRequest,
     /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
@@ -55,6 +93,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str {
     match format {
         PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
         PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
+        PromptFormat::NumberedLines => NUMBERED_LINES_SYSTEM_PROMPT,
         // only intended for use via zeta_cli
         PromptFormat::OnlySnippets => "",
     }
@@ -63,7 +102,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str {
 #[derive(Clone, Debug)]
 pub struct PlannedSnippet<'a> {
     path: Arc<Path>,
-    range: Range<usize>,
+    range: Range<Line>,
     text: &'a str,
     // TODO: Indicate this in the output
     #[allow(dead_code)]
@@ -79,7 +118,7 @@ pub enum DeclarationStyle {
 #[derive(Clone, Debug, Serialize)]
 pub struct SectionLabels {
     pub excerpt_index: usize,
-    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
+    pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
 }
 
 impl<'a> PlannedPrompt<'a> {
@@ -196,10 +235,24 @@ impl<'a> PlannedPrompt<'a> {
                             declaration.text.len()
                         ));
                     };
+                    let signature_start_line = declaration.range.start
+                        + Line(
+                            declaration.text[..declaration.signature_range.start]
+                                .lines()
+                                .count() as u32,
+                        );
+                    let signature_end_line = signature_start_line
+                        + Line(
+                            declaration.text
+                                [declaration.signature_range.start..declaration.signature_range.end]
+                                .lines()
+                                .count() as u32,
+                        );
+                    let range = signature_start_line..signature_end_line;
+
                     PlannedSnippet {
                         path: declaration.path.clone(),
-                        range: (declaration.signature_range.start + declaration.range.start)
-                            ..(declaration.signature_range.end + declaration.range.start),
+                        range,
                         text,
                         text_is_truncated: declaration.text_is_truncated,
                     }
@@ -318,7 +371,7 @@ impl<'a> PlannedPrompt<'a> {
         }
         let excerpt_snippet = PlannedSnippet {
             path: self.request.excerpt_path.clone(),
-            range: self.request.excerpt_range.clone(),
+            range: self.request.excerpt_line_range.clone(),
             text: &self.request.excerpt,
             text_is_truncated: false,
         };
@@ -328,32 +381,33 @@ impl<'a> PlannedPrompt<'a> {
         let mut excerpt_file_insertions = match self.request.prompt_format {
             PromptFormat::MarkedExcerpt => vec![
                 (
-                    self.request.excerpt_range.start,
+                    Point {
+                        line: self.request.excerpt_line_range.start,
+                        column: 0,
+                    },
                     EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
                 ),
+                (self.request.cursor_point, CURSOR_MARKER),
                 (
-                    self.request.excerpt_range.start + self.request.cursor_offset,
-                    CURSOR_MARKER,
-                ),
-                (
-                    self.request
-                        .excerpt_range
-                        .end
-                        .saturating_sub(0)
-                        .max(self.request.excerpt_range.start),
+                    Point {
+                        line: self.request.excerpt_line_range.end,
+                        column: 0,
+                    },
                     EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
                 ),
             ],
-            PromptFormat::LabeledSections => vec![(
-                self.request.excerpt_range.start + self.request.cursor_offset,
-                CURSOR_MARKER,
-            )],
+            PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
+            PromptFormat::NumberedLines => vec![(self.request.cursor_point, CURSOR_MARKER)],
             PromptFormat::OnlySnippets => vec![],
         };
 
         let mut prompt = String::new();
         prompt.push_str("## User Edits\n\n");
-        Self::push_events(&mut prompt, &self.request.events);
+        if self.request.events.is_empty() {
+            prompt.push_str("No edits yet.\n");
+        } else {
+            Self::push_events(&mut prompt, &self.request.events);
+        }
 
         prompt.push_str("\n## Code\n\n");
         let section_labels =
@@ -391,13 +445,17 @@ impl<'a> PlannedPrompt<'a> {
                     if *predicted {
                         writeln!(
                             output,
-                            "User accepted prediction {:?}:\n```diff\n{}\n```\n",
+                            "User accepted prediction {:?}:\n`````diff\n{}\n`````\n",
                             path, diff
                         )
                         .unwrap();
                     } else {
-                        writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
-                            .unwrap();
+                        writeln!(
+                            output,
+                            "User edited {:?}:\n`````diff\n{}\n`````\n",
+                            path, diff
+                        )
+                        .unwrap();
                     }
                 }
             }
@@ -407,7 +465,7 @@ impl<'a> PlannedPrompt<'a> {
     fn push_file_snippets(
         &self,
         output: &mut String,
-        excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
+        excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
         file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
     ) -> Result<SectionLabels> {
         let mut section_ranges = Vec::new();
@@ -417,15 +475,13 @@ impl<'a> PlannedPrompt<'a> {
             snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
 
             // TODO: What if the snippets get expanded too large to be editable?
-            let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
-            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
+            let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
+            let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
             for snippet in snippets {
                 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
-                    && snippet.range.start < current_snippet_range.end
+                    && snippet.range.start <= current_snippet_range.end
                 {
-                    if snippet.range.end > current_snippet_range.end {
-                        current_snippet_range.end = snippet.range.end;
-                    }
+                    current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
                     continue;
                 }
                 if let Some(current_snippet) = current_snippet.take() {
@@ -437,21 +493,24 @@ impl<'a> PlannedPrompt<'a> {
                 disjoint_snippets.push(current_snippet);
             }
 
-            writeln!(output, "```{}", file_path.display()).ok();
+            // TODO: remove filename=?
+            writeln!(output, "`````filename={}", file_path.display()).ok();
             let mut skipped_last_snippet = false;
             for (snippet, range) in disjoint_snippets {
                 let section_index = section_ranges.len();
 
                 match self.request.prompt_format {
-                    PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
-                        if range.start > 0 && !skipped_last_snippet {
+                    PromptFormat::MarkedExcerpt
+                    | PromptFormat::OnlySnippets
+                    | PromptFormat::NumberedLines => {
+                        if range.start.0 > 0 && !skipped_last_snippet {
                             output.push_str("…\n");
                         }
                     }
                     PromptFormat::LabeledSections => {
                         if is_excerpt_file
-                            && range.start <= self.request.excerpt_range.start
-                            && range.end >= self.request.excerpt_range.end
+                            && range.start <= self.request.excerpt_line_range.start
+                            && range.end >= self.request.excerpt_line_range.end
                         {
                             writeln!(output, "<|current_section|>").ok();
                         } else {
@@ -460,46 +519,83 @@ impl<'a> PlannedPrompt<'a> {
                     }
                 }
 
+                let push_full_snippet = |output: &mut String| {
+                    if self.request.prompt_format == PromptFormat::NumberedLines {
+                        for (i, line) in snippet.text.lines().enumerate() {
+                            writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
+                        }
+                    } else {
+                        output.push_str(&snippet.text);
+                    }
+                    anyhow::Ok(())
+                };
+
                 if is_excerpt_file {
                     if self.request.prompt_format == PromptFormat::OnlySnippets {
-                        if range.start >= self.request.excerpt_range.start
-                            && range.end <= self.request.excerpt_range.end
+                        if range.start >= self.request.excerpt_line_range.start
+                            && range.end <= self.request.excerpt_line_range.end
                         {
                             skipped_last_snippet = true;
                         } else {
                             skipped_last_snippet = false;
                             output.push_str(snippet.text);
                         }
-                    } else {
-                        let mut last_offset = range.start;
-                        let mut i = 0;
-                        while i < excerpt_file_insertions.len() {
-                            let (offset, insertion) = &excerpt_file_insertions[i];
-                            let found = *offset >= range.start && *offset <= range.end;
+                    } else if !excerpt_file_insertions.is_empty() {
+                        let lines = snippet.text.lines().collect::<Vec<_>>();
+                        let push_line = |output: &mut String, line_ix: usize| {
+                            if self.request.prompt_format == PromptFormat::NumberedLines {
+                                write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
+                            }
+                            anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
+                        };
+                        let mut last_line_ix = 0;
+                        let mut insertion_ix = 0;
+                        while insertion_ix < excerpt_file_insertions.len() {
+                            let (point, insertion) = &excerpt_file_insertions[insertion_ix];
+                            let found = point.line >= range.start && point.line <= range.end;
                             if found {
                                 excerpt_index = Some(section_index);
-                                output.push_str(
-                                    &snippet.text[last_offset - range.start..offset - range.start],
-                                );
-                                output.push_str(insertion);
-                                last_offset = *offset;
-                                excerpt_file_insertions.remove(i);
+                                let insertion_line_ix = (point.line.0 - range.start.0) as usize;
+                                for line_ix in last_line_ix..insertion_line_ix {
+                                    push_line(output, line_ix)?;
+                                }
+                                if let Some(next_line) = lines.get(insertion_line_ix) {
+                                    if self.request.prompt_format == PromptFormat::NumberedLines {
+                                        write!(
+                                            output,
+                                            "{}|",
+                                            insertion_line_ix as u32 + range.start.0 + 1
+                                        )?
+                                    }
+                                    output.push_str(&next_line[..point.column as usize]);
+                                    output.push_str(insertion);
+                                    writeln!(output, "{}", &next_line[point.column as usize..])?;
+                                } else {
+                                    writeln!(output, "{}", insertion)?;
+                                }
+                                last_line_ix = insertion_line_ix + 1;
+                                excerpt_file_insertions.remove(insertion_ix);
                                 continue;
                             }
-                            i += 1;
+                            insertion_ix += 1;
                         }
                         skipped_last_snippet = false;
-                        output.push_str(&snippet.text[last_offset - range.start..]);
+                        for line_ix in last_line_ix..lines.len() {
+                            push_line(output, line_ix)?;
+                        }
+                    } else {
+                        skipped_last_snippet = false;
+                        push_full_snippet(output)?;
                     }
                 } else {
                     skipped_last_snippet = false;
-                    output.push_str(snippet.text);
+                    push_full_snippet(output)?;
                 }
 
                 section_ranges.push((snippet.path.clone(), range));
             }
 
-            output.push_str("```\n\n");
+            output.push_str("`````\n\n");
         }
 
         Ok(SectionLabels {

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -1,3 +1,4 @@
+use cloud_llm_client::predict_edits_v3::{self, Line};
 use language::{Language, LanguageId};
 use project::ProjectEntryId;
 use std::ops::Range;
@@ -91,6 +92,18 @@ impl Declaration {
         }
     }
 
+    pub fn item_line_range(&self) -> Range<Line> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.item_line_range.clone(),
+            Declaration::Buffer {
+                declaration, rope, ..
+            } => {
+                Line(rope.offset_to_point(declaration.item_range.start).row)
+                    ..Line(rope.offset_to_point(declaration.item_range.end).row)
+            }
+        }
+    }
+
     pub fn item_text(&self) -> (Cow<'_, str>, bool) {
         match self {
             Declaration::File { declaration, .. } => (
@@ -130,6 +143,18 @@ impl Declaration {
         }
     }
 
+    pub fn signature_line_range(&self) -> Range<Line> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.signature_line_range.clone(),
+            Declaration::Buffer {
+                declaration, rope, ..
+            } => {
+                Line(rope.offset_to_point(declaration.signature_range.start).row)
+                    ..Line(rope.offset_to_point(declaration.signature_range.end).row)
+            }
+        }
+    }
+
     pub fn signature_range_in_item_text(&self) -> Range<usize> {
         let signature_range = self.signature_range();
         let item_range = self.item_range();
@@ -142,7 +167,7 @@ fn expand_range_to_line_boundaries_and_truncate(
     range: &Range<usize>,
     limit: usize,
     rope: &Rope,
-) -> (Range<usize>, bool) {
+) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
     let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
     point_range.start.column = 0;
     point_range.end.row += 1;
@@ -155,7 +180,10 @@ fn expand_range_to_line_boundaries_and_truncate(
         item_range.end = item_range.start + limit;
     }
     item_range.end = rope.clip_offset(item_range.end, Bias::Left);
-    (item_range, is_truncated)
+
+    let line_range =
+        predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
+    (item_range, line_range, is_truncated)
 }
 
 #[derive(Debug, Clone)]
@@ -164,25 +192,30 @@ pub struct FileDeclaration {
     pub identifier: Identifier,
     /// offset range of the declaration in the file, expanded to line boundaries and truncated
     pub item_range: Range<usize>,
+    /// line range of the declaration in the file, potentially truncated
+    pub item_line_range: Range<predict_edits_v3::Line>,
     /// text of `item_range`
     pub text: Arc<str>,
     /// whether `text` was truncated
     pub text_is_truncated: bool,
     /// offset range of the signature in the file, expanded to line boundaries and truncated
     pub signature_range: Range<usize>,
+    /// line range of the signature in the file, truncated
+    pub signature_line_range: Range<Line>,
     /// whether `signature` was truncated
     pub signature_is_truncated: bool,
 }
 
 impl FileDeclaration {
     pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
-        let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
-            &declaration.item_range,
-            ITEM_TEXT_TRUNCATION_LENGTH,
-            rope,
-        );
+        let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.item_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
 
-        let (mut signature_range_in_file, mut signature_is_truncated) =
+        let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
             expand_range_to_line_boundaries_and_truncate(
                 &declaration.signature_range,
                 ITEM_TEXT_TRUNCATION_LENGTH,
@@ -202,6 +235,7 @@ impl FileDeclaration {
             parent: None,
             identifier: declaration.identifier,
             signature_range: signature_range_in_file,
+            signature_line_range,
             signature_is_truncated,
             text: rope
                 .chunks_in_range(item_range_in_file.clone())
@@ -209,6 +243,7 @@ impl FileDeclaration {
                 .into(),
             text_is_truncated,
             item_range: item_range_in_file,
+            item_line_range: item_line_range_in_file,
         }
     }
 }
@@ -225,12 +260,13 @@ pub struct BufferDeclaration {
 
 impl BufferDeclaration {
     pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
-        let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
-            &declaration.item_range,
-            ITEM_TEXT_TRUNCATION_LENGTH,
-            rope,
-        );
-        let (signature_range, signature_range_is_truncated) =
+        let (item_range, _item_line_range, item_range_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.item_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
+        let (signature_range, _signature_line_range, signature_range_is_truncated) =
             expand_range_to_line_boundaries_and_truncate(
                 &declaration.signature_range,
                 ITEM_TEXT_TRUNCATION_LENGTH,

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -9,6 +9,7 @@ pub mod text_similarity;
 
 use std::{path::Path, sync::Arc};
 
+use cloud_llm_client::predict_edits_v3;
 use collections::HashMap;
 use gpui::{App, AppContext as _, Entity, Task};
 use language::BufferSnapshot;
@@ -21,6 +22,8 @@ pub use imports::*;
 pub use reference::*;
 pub use syntax_index::*;
 
+pub use predict_edits_v3::Line;
+
 #[derive(Clone, Debug, PartialEq)]
 pub struct EditPredictionContextOptions {
     pub use_imports: bool,
@@ -32,7 +35,7 @@ pub struct EditPredictionContextOptions {
 pub struct EditPredictionContext {
     pub excerpt: EditPredictionExcerpt,
     pub excerpt_text: EditPredictionExcerptText,
-    pub cursor_offset_in_excerpt: usize,
+    pub cursor_point: Point,
     pub declarations: Vec<ScoredDeclaration>,
 }
 
@@ -124,8 +127,6 @@ impl EditPredictionContext {
         );
 
         let cursor_offset_in_file = cursor_point.to_offset(buffer);
-        // TODO fix this to not need saturating_sub
-        let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 
         let declarations = if let Some(index_state) = index_state {
             let references = get_references(&excerpt, &excerpt_text, buffer);
@@ -148,7 +149,7 @@ impl EditPredictionContext {
         Some(Self {
             excerpt,
             excerpt_text,
-            cursor_offset_in_excerpt,
+            cursor_point,
             declarations,
         })
     }

crates/edit_prediction_context/src/excerpt.rs 🔗

@@ -4,7 +4,7 @@ use text::{Point, ToOffset as _, ToPoint as _};
 use tree_sitter::{Node, TreeCursor};
 use util::RangeExt;
 
-use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
+use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
 
 // TODO:
 //
@@ -35,6 +35,7 @@ pub struct EditPredictionExcerptOptions {
 #[derive(Debug, Clone)]
 pub struct EditPredictionExcerpt {
     pub range: Range<usize>,
+    pub line_range: Range<Line>,
     pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
     pub size: usize,
 }
@@ -86,12 +87,19 @@ impl EditPredictionExcerpt {
                 buffer.len(),
                 options.max_bytes
             );
-            return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
+            let offset_range = 0..buffer.len();
+            let line_range = Line(0)..Line(buffer.max_point().row);
+            return Some(EditPredictionExcerpt::new(
+                offset_range,
+                line_range,
+                Vec::new(),
+            ));
         }
 
         let query_offset = query_point.to_offset(buffer);
-        let query_range = Point::new(query_point.row, 0).to_offset(buffer)
-            ..Point::new(query_point.row + 1, 0).to_offset(buffer);
+        let query_line_range = query_point.row..query_point.row + 1;
+        let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
+            ..Point::new(query_line_range.end, 0).to_offset(buffer);
         if query_range.len() >= options.max_bytes {
             return None;
         }
@@ -107,6 +115,7 @@ impl EditPredictionExcerpt {
         let excerpt_selector = ExcerptSelector {
             query_offset,
             query_range,
+            query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
             parent_declarations: &parent_declarations,
             buffer,
             options,
@@ -130,7 +139,11 @@ impl EditPredictionExcerpt {
         excerpt_selector.select_lines()
     }
 
-    fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
+    fn new(
+        range: Range<usize>,
+        line_range: Range<Line>,
+        parent_declarations: Vec<(DeclarationId, Range<usize>)>,
+    ) -> Self {
         let size = range.len()
             + parent_declarations
                 .iter()
@@ -140,10 +153,11 @@ impl EditPredictionExcerpt {
             range,
             parent_declarations,
             size,
+            line_range,
         }
     }
 
-    fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
+    fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
         if !new_range.contains_inclusive(&self.range) {
             // this is an issue because parent_signature_ranges may be incorrect
             log::error!("bug: with_expanded_range called with disjoint range");
@@ -155,7 +169,7 @@ impl EditPredictionExcerpt {
             }
             parent_declarations.push((*declaration_id, range.clone()));
         }
-        Self::new(new_range, parent_declarations)
+        Self::new(new_range, new_line_range, parent_declarations)
     }
 
     fn parent_signatures_size(&self) -> usize {
@@ -166,6 +180,7 @@ impl EditPredictionExcerpt {
 struct ExcerptSelector<'a> {
     query_offset: usize,
     query_range: Range<usize>,
+    query_line_range: Range<Line>,
     parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
     buffer: &'a BufferSnapshot,
     options: &'a EditPredictionExcerptOptions,
@@ -178,10 +193,13 @@ impl<'a> ExcerptSelector<'a> {
         let mut cursor = selected_layer_root.walk();
 
         loop {
-            let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
-                ..node_line_end(cursor.node()).to_offset(&self.buffer);
+            let line_start = node_line_start(cursor.node());
+            let line_end = node_line_end(cursor.node());
+            let line_range = Line(line_start.row)..Line(line_end.row);
+            let excerpt_range =
+                line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
             if excerpt_range.contains_inclusive(&self.query_range) {
-                let excerpt = self.make_excerpt(excerpt_range);
+                let excerpt = self.make_excerpt(excerpt_range, line_range);
                 if excerpt.size <= self.options.max_bytes {
                     return Some(self.expand_to_siblings(&mut cursor, excerpt));
                 }
@@ -272,9 +290,13 @@ impl<'a> ExcerptSelector<'a> {
 
             let mut forward = None;
             while !forward_done {
-                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
+                let new_end_point = node_line_end(forward_cursor.node());
+                let new_end = new_end_point.to_offset(&self.buffer);
                 if new_end > excerpt.range.end {
-                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
+                    let new_excerpt = excerpt.with_expanded_range(
+                        excerpt.range.start..new_end,
+                        excerpt.line_range.start..Line(new_end_point.row),
+                    );
                     if new_excerpt.size <= self.options.max_bytes {
                         forward = Some(new_excerpt);
                         break;
@@ -289,9 +311,13 @@ impl<'a> ExcerptSelector<'a> {
 
             let mut backward = None;
             while !backward_done {
-                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
+                let new_start_point = node_line_start(backward_cursor.node());
+                let new_start = new_start_point.to_offset(&self.buffer);
                 if new_start < excerpt.range.start {
-                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
+                    let new_excerpt = excerpt.with_expanded_range(
+                        new_start..excerpt.range.end,
+                        Line(new_start_point.row)..excerpt.line_range.end,
+                    );
                     if new_excerpt.size <= self.options.max_bytes {
                         backward = Some(new_excerpt);
                         break;
@@ -339,7 +365,7 @@ impl<'a> ExcerptSelector<'a> {
 
     fn select_lines(&self) -> Option<EditPredictionExcerpt> {
         // early return if line containing query_offset is already too large
-        let excerpt = self.make_excerpt(self.query_range.clone());
+        let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
         if excerpt.size > self.options.max_bytes {
             log::debug!(
                 "excerpt for cursor line is {} bytes, which exceeds the window",
@@ -353,24 +379,24 @@ impl<'a> ExcerptSelector<'a> {
         let before_bytes =
             (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
 
-        let start_point = {
+        let start_line = {
             let offset = self.query_offset.saturating_sub(before_bytes);
             let point = offset.to_point(self.buffer);
-            Point::new(point.row + 1, 0)
+            Line(point.row + 1)
         };
-        let start_offset = start_point.to_offset(&self.buffer);
-        let end_point = {
+        let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
+        let end_line = {
             let offset = start_offset + bytes_remaining;
             let point = offset.to_point(self.buffer);
-            Point::new(point.row, 0)
+            Line(point.row)
         };
-        let end_offset = end_point.to_offset(&self.buffer);
+        let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
 
         // this could be expanded further since recalculated `signature_size` may be smaller, but
         // skipping that for now for simplicity
         //
         // TODO: could also consider checking if lines immediately before / after fit.
-        let excerpt = self.make_excerpt(start_offset..end_offset);
+        let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
         if excerpt.size > self.options.max_bytes {
             log::error!(
                 "bug: line-based excerpt selection has size {}, \
@@ -382,14 +408,14 @@ impl<'a> ExcerptSelector<'a> {
         return Some(excerpt);
     }
 
-    fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
+    fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
         let parent_declarations = self
             .parent_declarations
             .iter()
             .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
             .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
             .collect();
-        EditPredictionExcerpt::new(range, parent_declarations)
+        EditPredictionExcerpt::new(range, line_range, parent_declarations)
     }
 
     /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.

crates/zeta2/src/prediction.rs 🔗

@@ -33,7 +33,7 @@ pub struct EditPrediction {
     pub snapshot: BufferSnapshot,
     pub edit_preview: EditPreview,
     // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
-    _buffer: Entity<Buffer>,
+    pub buffer: Entity<Buffer>,
 }
 
 impl EditPrediction {
@@ -108,7 +108,7 @@ impl EditPrediction {
             edits,
             snapshot,
             edit_preview,
-            _buffer: buffer,
+            buffer,
         })
     }
 
@@ -184,6 +184,10 @@ pub fn interpolate_edits(
     if edits.is_empty() { None } else { Some(edits) }
 }
 
+pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
+    language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
+}
+
 fn edits_from_response(
     edits: &[predict_edits_v3::Edit],
     snapshot: &TextBufferSnapshot,
@@ -191,12 +195,14 @@ fn edits_from_response(
     edits
         .iter()
         .flat_map(|edit| {
-            let old_text = snapshot.text_for_range(edit.range.clone());
+            let point_range = line_range_to_point_range(edit.range.clone());
+            let offset = point_range.to_offset(snapshot).start;
+            let old_text = snapshot.text_for_range(point_range);
 
             excerpt_edits_from_response(
                 old_text.collect::<Cow<str>>(),
                 &edit.content,
-                edit.range.start,
+                offset,
                 &snapshot,
             )
         })
@@ -252,6 +258,7 @@ mod tests {
 
     use super::*;
     use cloud_llm_client::predict_edits_v3;
+    use edit_prediction_context::Line;
     use gpui::{App, Entity, TestAppContext, prelude::*};
     use indoc::indoc;
     use language::{Buffer, ToOffset as _};
@@ -278,7 +285,7 @@ mod tests {
         // TODO cover more cases when multi-file is supported
         let big_edits = vec![predict_edits_v3::Edit {
             path: PathBuf::from("test.txt").into(),
-            range: 0..old.len(),
+            range: Line(0)..Line(old.lines().count() as u32),
             content: new.into(),
         }];
 
@@ -317,7 +324,7 @@ mod tests {
             edits,
             snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
             path: Path::new("test.txt").into(),
-            _buffer: buffer.clone(),
+            buffer: buffer.clone(),
             edit_preview,
         };
 

crates/zeta2/src/zeta2.rs 🔗

@@ -17,8 +17,8 @@ use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
     http_client, prelude::*,
 };
+use language::BufferSnapshot;
 use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
-use language::{BufferSnapshot, TextBufferSnapshot};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
@@ -106,30 +106,40 @@ struct ZetaProject {
     current_prediction: Option<CurrentEditPrediction>,
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 struct CurrentEditPrediction {
     pub requested_by_buffer_id: EntityId,
     pub prediction: EditPrediction,
 }
 
 impl CurrentEditPrediction {
-    fn should_replace_prediction(
-        &self,
-        old_prediction: &Self,
-        snapshot: &TextBufferSnapshot,
-    ) -> bool {
-        if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
+    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
+        let Some(new_edits) = self
+            .prediction
+            .interpolate(&self.prediction.buffer.read(cx))
+        else {
+            return false;
+        };
+
+        if self.prediction.buffer != old_prediction.prediction.buffer {
             return true;
         }
 
-        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+        let Some(old_edits) = old_prediction
+            .prediction
+            .interpolate(&old_prediction.prediction.buffer.read(cx))
+        else {
             return true;
         };
 
-        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
-            return false;
-        };
-        if old_edits.len() == 1 && new_edits.len() == 1 {
+        // This reduces the occurrence of UI thrash from replacing edits
+        //
+        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
+        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
+            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
+            && old_edits.len() == 1
+            && new_edits.len() == 1
+        {
             let (old_range, old_text) = &old_edits[0];
             let (new_range, new_text) = &new_edits[0];
             new_range == old_range && new_text.starts_with(old_text)
@@ -421,8 +431,7 @@ impl Zeta {
                         .current_prediction
                         .as_ref()
                         .is_none_or(|old_prediction| {
-                            new_prediction
-                                .should_replace_prediction(&old_prediction, buffer.read(cx))
+                            new_prediction.should_replace_prediction(&old_prediction, cx)
                         })
                     {
                         project_state.current_prediction = Some(new_prediction);
@@ -926,7 +935,7 @@ fn make_cloud_request(
         referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
             path: path.as_std_path().into(),
             text: text.into(),
-            range: snippet.declaration.item_range(),
+            range: snippet.declaration.item_line_range(),
             text_is_truncated,
             signature_range: snippet.declaration.signature_range_in_item_text(),
             parent_index,
@@ -954,8 +963,12 @@ fn make_cloud_request(
     predict_edits_v3::PredictEditsRequest {
         excerpt_path,
         excerpt: context.excerpt_text.body,
+        excerpt_line_range: context.excerpt.line_range,
         excerpt_range: context.excerpt.range,
-        cursor_offset: context.cursor_offset_in_excerpt,
+        cursor_point: predict_edits_v3::Point {
+            line: predict_edits_v3::Line(context.cursor_point.row),
+            column: context.cursor_point.column,
+        },
         referenced_declarations,
         signatures,
         excerpt_parent,
@@ -992,7 +1005,7 @@ fn add_signature(
         text: text.into(),
         text_is_truncated,
         parent_index,
-        range: parent_declaration.signature_range(),
+        range: parent_declaration.signature_line_range(),
     });
     declaration_to_signature_index.insert(declaration_id, signature_index);
     Some(signature_index)
@@ -1007,7 +1020,8 @@ mod tests {
 
     use client::UserStore;
     use clock::FakeSystemClock;
-    use cloud_llm_client::predict_edits_v3;
+    use cloud_llm_client::predict_edits_v3::{self, Point};
+    use edit_prediction_context::Line;
     use futures::{
         AsyncReadExt, StreamExt,
         channel::{mpsc, oneshot},
@@ -1067,7 +1081,7 @@ mod tests {
                 request_id: Uuid::new_v4(),
                 edits: vec![predict_edits_v3::Edit {
                     path: Path::new(path!("root/1.txt")).into(),
-                    range: 0..snapshot1.len(),
+                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                     content: "Hello!\nHow are you?\nBye".into(),
                 }],
                 debug_info: None,
@@ -1083,7 +1097,6 @@ mod tests {
         });
 
         // Prediction for another file
-
         let prediction_task = zeta.update(cx, |zeta, cx| {
             zeta.refresh_prediction(&project, &buffer1, position, cx)
         });
@@ -1093,14 +1106,13 @@ mod tests {
                 request_id: Uuid::new_v4(),
                 edits: vec![predict_edits_v3::Edit {
                     path: Path::new(path!("root/2.txt")).into(),
-                    range: 0..snapshot1.len(),
+                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                     content: "Hola!\nComo estas?\nAdios".into(),
                 }],
                 debug_info: None,
             })
             .unwrap();
         prediction_task.await.unwrap();
-
         zeta.read_with(cx, |zeta, cx| {
             let prediction = zeta
                 .current_prediction_for_buffer(&buffer1, &project, cx)
@@ -1159,14 +1171,20 @@ mod tests {
             request.excerpt_path.as_ref(),
             Path::new(path!("root/foo.md"))
         );
-        assert_eq!(request.cursor_offset, 10);
+        assert_eq!(
+            request.cursor_point,
+            Point {
+                line: Line(1),
+                column: 3
+            }
+        );
 
         respond_tx
             .send(predict_edits_v3::PredictEditsResponse {
                 request_id: Uuid::new_v4(),
                 edits: vec![predict_edits_v3::Edit {
                     path: Path::new(path!("root/foo.md")).into(),
-                    range: 0..snapshot.len(),
+                    range: Line(0)..Line(snapshot.max_point().row + 1),
                     content: "Hello!\nHow are you?\nBye".into(),
                 }],
                 debug_info: None,
@@ -1244,7 +1262,7 @@ mod tests {
                 request_id: Uuid::new_v4(),
                 edits: vec![predict_edits_v3::Edit {
                     path: Path::new(path!("root/foo.md")).into(),
-                    range: 0..snapshot.len(),
+                    range: Line(0)..Line(snapshot.max_point().row + 1),
                     content: "Hello!\nHow are you?\nBye".into(),
                 }],
                 debug_info: None,

crates/zeta_cli/src/main.rs 🔗

@@ -98,10 +98,11 @@ struct Zeta2Args {
 
 #[derive(clap::ValueEnum, Default, Debug, Clone)]
 enum PromptFormat {
-    #[default]
     MarkedExcerpt,
     LabeledSections,
     OnlySnippets,
+    #[default]
+    NumberedLines,
 }
 
 impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
@@ -110,6 +111,7 @@ impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
             Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
             Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
             Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
+            Self::NumberedLines => predict_edits_v3::PromptFormat::NumberedLines,
         }
     }
 }