edit prediction: Try to expand context to parent treesitter region (#24186)

Bennet Bo Fenner and Antonio Scandurra created

Also send the `speculated_output` (which is just the editable region) to
the llm backend

Closes #ISSUE

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/rpc/src/llm.rs            |   1 
crates/zeta/src/input_excerpt.rs | 238 ++++++++++++++++++++++++++++++
crates/zeta/src/zeta.rs          | 263 +++++++--------------------------
3 files changed, 296 insertions(+), 206 deletions(-)

Detailed changes

crates/rpc/src/llm.rs 🔗

@@ -39,6 +39,7 @@ pub struct PredictEditsParams {
     pub outline: Option<String>,
     pub input_events: String,
     pub input_excerpt: String,
+    pub speculated_output: String,
     /// Whether the user provided consent for sampling this interaction.
     #[serde(default)]
     pub data_collection_permission: bool,

crates/zeta/src/input_excerpt.rs 🔗

@@ -0,0 +1,238 @@
+use crate::{
+    BYTES_PER_TOKEN_GUESS, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
+    START_OF_FILE_MARKER,
+};
+use language::{BufferSnapshot, Point};
+use std::{fmt::Write, ops::Range};
+
+pub struct InputExcerpt {
+    pub editable_range: Range<Point>,
+    pub prompt: String,
+    pub speculated_output: String,
+}
+
+pub fn excerpt_for_cursor_position(
+    position: Point,
+    path: &str,
+    snapshot: &BufferSnapshot,
+    editable_region_token_limit: usize,
+    context_token_limit: usize,
+) -> InputExcerpt {
+    let mut scope_range = position..position;
+    let mut remaining_edit_tokens = editable_region_token_limit;
+
+    while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
+        let parent_tokens = tokens_for_bytes(parent.byte_range().len());
+        if parent_tokens <= editable_region_token_limit {
+            scope_range = Point::new(
+                parent.start_position().row as u32,
+                parent.start_position().column as u32,
+            )
+                ..Point::new(
+                    parent.end_position().row as u32,
+                    parent.end_position().column as u32,
+                );
+            remaining_edit_tokens = editable_region_token_limit - parent_tokens;
+        } else {
+            break;
+        }
+    }
+
+    let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
+    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
+
+    let mut prompt = String::new();
+    let mut speculated_output = String::new();
+
+    writeln!(&mut prompt, "```{path}").unwrap();
+    if context_range.start == Point::zero() {
+        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
+    }
+
+    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
+        prompt.push_str(chunk.text);
+    }
+
+    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
+    push_editable_range(
+        position,
+        snapshot,
+        editable_range.clone(),
+        &mut speculated_output,
+    );
+
+    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
+        prompt.push_str(chunk.text);
+    }
+    write!(prompt, "\n```").unwrap();
+
+    InputExcerpt {
+        editable_range,
+        prompt,
+        speculated_output,
+    }
+}
+
+fn push_editable_range(
+    cursor_position: Point,
+    snapshot: &BufferSnapshot,
+    editable_range: Range<Point>,
+    prompt: &mut String,
+) {
+    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
+    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
+        prompt.push_str(chunk.text);
+    }
+    prompt.push_str(CURSOR_MARKER);
+    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
+        prompt.push_str(chunk.text);
+    }
+    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
+}
+
+fn expand_range(
+    snapshot: &BufferSnapshot,
+    range: Range<Point>,
+    mut remaining_tokens: usize,
+) -> Range<Point> {
+    let mut expanded_range = range.clone();
+    expanded_range.start.column = 0;
+    expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+    loop {
+        let mut expanded = false;
+
+        if remaining_tokens > 0 && expanded_range.start.row > 0 {
+            expanded_range.start.row -= 1;
+            let line_tokens =
+                tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize);
+            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+            expanded = true;
+        }
+
+        if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
+            expanded_range.end.row += 1;
+            expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+            let line_tokens = tokens_for_bytes(expanded_range.end.column as usize);
+            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+            expanded = true;
+        }
+
+        if !expanded {
+            break;
+        }
+    }
+    expanded_range
+}
+
+fn tokens_for_bytes(bytes: usize) -> usize {
+    bytes / BYTES_PER_TOKEN_GUESS
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::{App, AppContext};
+    use indoc::indoc;
+    use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
+    use std::sync::Arc;
+
+    #[gpui::test]
+    fn test_excerpt_for_cursor_position(cx: &mut App) {
+        let text = indoc! {r#"
+            fn foo() {
+                let x = 42;
+                println!("Hello, world!");
+            }
+
+            fn bar() {
+                let x = 42;
+                let mut sum = 0;
+                for i in 0..x {
+                    sum += i;
+                }
+                println!("Sum: {}", sum);
+                return sum;
+            }
+
+            fn generate_random_numbers() -> Vec<i32> {
+                let mut rng = rand::thread_rng();
+                let mut numbers = Vec::new();
+                for _ in 0..5 {
+                    numbers.push(rng.gen_range(1..101));
+                }
+                numbers
+            }
+        "#};
+        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+        let snapshot = buffer.read(cx).snapshot();
+
+        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
+        // when a larger scope doesn't fit the editable region.
+        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
+        assert_eq!(
+            excerpt.prompt,
+            indoc! {r#"
+            ```main.rs
+                let x = 42;
+                println!("Hello, world!");
+            <|editable_region_start|>
+            }
+
+            fn bar() {
+                let x = 42;
+                let mut sum = 0;
+                for i in 0..x {
+                    sum += i;
+                }
+                println!("Sum: {}", sum);
+                r<|user_cursor_is_here|>eturn sum;
+            }
+
+            fn generate_random_numbers() -> Vec<i32> {
+            <|editable_region_end|>
+                let mut rng = rand::thread_rng();
+                let mut numbers = Vec::new();
+            ```"#}
+        );
+
+        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
+        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
+        assert_eq!(
+            excerpt.prompt,
+            indoc! {r#"
+            ```main.rs
+            fn bar() {
+                let x = 42;
+                let mut sum = 0;
+            <|editable_region_start|>
+                for i in 0..x {
+                    sum += i;
+                }
+                println!("Sum: {}", sum);
+                r<|user_cursor_is_here|>eturn sum;
+            }
+
+            fn generate_random_numbers() -> Vec<i32> {
+                let mut rng = rand::thread_rng();
+            <|editable_region_end|>
+                let mut numbers = Vec::new();
+                for _ in 0..5 {
+                    numbers.push(rng.gen_range(1..101));
+            ```"#}
+        );
+    }
+
+    fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::LANGUAGE.into()),
+        )
+    }
+}

crates/zeta/src/zeta.rs 🔗

@@ -1,5 +1,6 @@
 mod completion_diff_element;
 mod init;
+mod input_excerpt;
 mod license_detection;
 mod onboarding_banner;
 mod onboarding_modal;
@@ -25,7 +26,7 @@ use gpui::{
 use http_client::{HttpClient, Method};
 use language::{
     language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview,
-    OffsetRangeExt, Point, ToOffset, ToPoint,
+    OffsetRangeExt, ToOffset, ToPoint,
 };
 use language_models::LlmApiToken;
 use postage::watch;
@@ -61,26 +62,26 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch
 /// intentionally low to err on the side of underestimating limits.
 const BYTES_PER_TOKEN_GUESS: usize = 3;
 
-/// Output token limit, used to inform the size of the input. A copy of this constant is also in
+/// Input token limit, used to inform the size of the input. A copy of this constant is also in
 /// `crates/collab/src/llm.rs`.
-const MAX_OUTPUT_TOKENS: usize = 2048;
+const MAX_INPUT_TOKENS: usize = 2048;
+
+const MAX_CONTEXT_TOKENS: usize = 64;
+const MAX_OUTPUT_TOKENS: usize = 256;
 
 /// Total bytes limit for editable region of buffer excerpt.
 ///
 /// The number of output tokens is relevant to the size of the input excerpt because the model is
 /// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
 /// remaining for the model to specify insertions.
-const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
-
-/// Total line limit for editable region of buffer excerpt.
-const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;
+const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_INPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
 
 /// Note that this is not the limit for the overall prompt, just for the inputs to the template
 /// instantiated in `crates/collab/src/llm.rs`.
 const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;
 
 /// Maximum number of events to include in the prompt.
-const MAX_EVENT_COUNT: usize = 16;
+const MAX_EVENT_COUNT: usize = 8;
 
 /// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of
 /// equally splitting up the the remaining bytes after the largest possible buffer excerpt.
@@ -373,8 +374,8 @@ impl Zeta {
         R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
     {
         let snapshot = self.report_changes_for_buffer(&buffer, cx);
-        let cursor_point = cursor.to_point(&snapshot);
-        let cursor_offset = cursor_point.to_offset(&snapshot);
+        let cursor_position = cursor.to_point(&snapshot);
+        let cursor_offset = cursor_position.to_offset(&snapshot);
         let events = self.events.clone();
         let path: Arc<Path> = snapshot
             .file()
@@ -389,45 +390,47 @@ impl Zeta {
         cx.spawn(|_, cx| async move {
             let request_sent_at = Instant::now();
 
-            let (input_events, input_excerpt, excerpt_range, input_outline) = cx
-                .background_executor()
-                .spawn({
-                    let snapshot = snapshot.clone();
-                    let path = path.clone();
-                    async move {
-                        let path = path.to_string_lossy();
-                        let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
-                            cursor_point,
-                            BUFFER_EXCERPT_BYTE_LIMIT,
-                            BUFFER_EXCERPT_LINE_LIMIT,
-                            &path,
-                            &snapshot,
-                        )?;
-                        let input_excerpt = prompt_for_excerpt(
-                            cursor_offset,
-                            &excerpt_range,
-                            excerpt_len_guess,
-                            &path,
-                            &snapshot,
-                        );
-
-                        let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
-                        let input_events = prompt_for_events(events.iter(), bytes_remaining);
-
-                        // Note that input_outline is not currently used in prompt generation and so
-                        // is not counted towards TOTAL_BYTE_LIMIT.
-                        let input_outline = prompt_for_outline(&snapshot);
-
-                        anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
-                    }
-                })
-                .await?;
+            let (input_events, input_excerpt, editable_range, input_outline, speculated_output) =
+                cx.background_executor()
+                    .spawn({
+                        let snapshot = snapshot.clone();
+                        let path = path.clone();
+                        async move {
+                            let path = path.to_string_lossy();
+                            let input_excerpt = input_excerpt::excerpt_for_cursor_position(
+                                cursor_position,
+                                &path,
+                                &snapshot,
+                                MAX_OUTPUT_TOKENS,
+                                MAX_CONTEXT_TOKENS,
+                            );
+
+                            let bytes_remaining =
+                                TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.prompt.len());
+                            let input_events = prompt_for_events(events.iter(), bytes_remaining);
+
+                            // Note that input_outline is not currently used in prompt generation and so
+                            // is not counted towards TOTAL_BYTE_LIMIT.
+                            let input_outline = prompt_for_outline(&snapshot);
+
+                            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+                            anyhow::Ok((
+                                input_events,
+                                input_excerpt.prompt,
+                                editable_range,
+                                input_outline,
+                                input_excerpt.speculated_output,
+                            ))
+                        }
+                    })
+                    .await?;
 
             log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 
             let body = PredictEditsParams {
                 input_events: input_events.clone(),
                 input_excerpt: input_excerpt.clone(),
+                speculated_output,
                 outline: Some(input_outline.clone()),
                 data_collection_permission,
             };
@@ -441,7 +444,7 @@ impl Zeta {
                 output_excerpt,
                 buffer,
                 &snapshot,
-                excerpt_range,
+                editable_range,
                 cursor_offset,
                 path,
                 input_outline,
@@ -457,6 +460,8 @@ impl Zeta {
     // Generates several example completions of various states to fill the Zeta completion modal
     #[cfg(any(test, feature = "test-support"))]
     pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
+        use language::Point;
+
         let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
             And maybe a short line
 
@@ -675,7 +680,7 @@ and then another
         output_excerpt: String,
         buffer: Entity<Buffer>,
         snapshot: &BufferSnapshot,
-        excerpt_range: Range<usize>,
+        editable_range: Range<usize>,
         cursor_offset: usize,
         path: Arc<Path>,
         input_outline: String,
@@ -692,9 +697,9 @@ and then another
                 .background_executor()
                 .spawn({
                     let output_excerpt = output_excerpt.clone();
-                    let excerpt_range = excerpt_range.clone();
+                    let editable_range = editable_range.clone();
                     let snapshot = snapshot.clone();
-                    async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) }
+                    async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
                 })
                 .await?
                 .into();
@@ -717,7 +722,7 @@ and then another
             Ok(Some(InlineCompletion {
                 id: InlineCompletionId::new(),
                 path,
-                excerpt_range,
+                excerpt_range: editable_range,
                 cursor_offset,
                 edits,
                 edit_preview,
@@ -734,7 +739,7 @@ and then another
 
     fn parse_edits(
         output_excerpt: Arc<str>,
-        excerpt_range: Range<usize>,
+        editable_range: Range<usize>,
         snapshot: &BufferSnapshot,
     ) -> Result<Vec<(Range<Anchor>, String)>> {
         let content = output_excerpt.replace(CURSOR_MARKER, "");
@@ -778,13 +783,13 @@ and then another
         let new_text = &content[..codefence_end];
 
         let old_text = snapshot
-            .text_for_range(excerpt_range.clone())
+            .text_for_range(editable_range.clone())
             .collect::<String>();
 
         Ok(Self::compute_edits(
             old_text,
             new_text,
-            excerpt_range.start,
+            editable_range.start,
             &snapshot,
         ))
     }
@@ -1011,161 +1016,6 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
     input_outline
 }
 
-fn prompt_for_excerpt(
-    offset: usize,
-    excerpt_range: &Range<usize>,
-    mut len_guess: usize,
-    path: &str,
-    snapshot: &BufferSnapshot,
-) -> String {
-    let point_range = excerpt_range.to_point(snapshot);
-
-    // Include one line of extra context before and after editable range, if those lines are non-empty.
-    let extra_context_before_range =
-        if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
-            let range =
-                (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
-            len_guess += range.end - range.start;
-            Some(range)
-        } else {
-            None
-        };
-    let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
-        && !snapshot.is_line_blank(point_range.end.row + 1)
-    {
-        let range = (point_range.end
-            ..Point::new(
-                point_range.end.row + 1,
-                snapshot.line_len(point_range.end.row + 1),
-            ))
-            .to_offset(snapshot);
-        len_guess += range.end - range.start;
-        Some(range)
-    } else {
-        None
-    };
-
-    let mut prompt_excerpt = String::with_capacity(len_guess);
-    writeln!(prompt_excerpt, "```{}", path).unwrap();
-
-    if excerpt_range.start == 0 {
-        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
-    }
-
-    if let Some(extra_context_before_range) = extra_context_before_range {
-        for chunk in snapshot.text_for_range(extra_context_before_range) {
-            prompt_excerpt.push_str(chunk);
-        }
-    }
-    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
-    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
-        prompt_excerpt.push_str(chunk);
-    }
-    prompt_excerpt.push_str(CURSOR_MARKER);
-    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
-        prompt_excerpt.push_str(chunk);
-    }
-    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
-
-    if let Some(extra_context_after_range) = extra_context_after_range {
-        for chunk in snapshot.text_for_range(extra_context_after_range) {
-            prompt_excerpt.push_str(chunk);
-        }
-    }
-
-    write!(prompt_excerpt, "\n```").unwrap();
-    debug_assert!(
-        prompt_excerpt.len() <= len_guess,
-        "Excerpt length {} exceeds estimated length {}",
-        prompt_excerpt.len(),
-        len_guess
-    );
-    prompt_excerpt
-}
-
-fn excerpt_range_for_position(
-    cursor_point: Point,
-    byte_limit: usize,
-    line_limit: u32,
-    path: &str,
-    snapshot: &BufferSnapshot,
-) -> Result<(Range<usize>, usize)> {
-    let cursor_row = cursor_point.row;
-    let last_buffer_row = snapshot.max_point().row;
-
-    // This is an overestimate because it includes parts of prompt_for_excerpt which are
-    // conditionally skipped.
-    let mut len_guess = 0;
-    len_guess += "```".len() + path.len() + 1;
-    len_guess += START_OF_FILE_MARKER.len() + 1;
-    len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
-    len_guess += CURSOR_MARKER.len();
-    len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
-    len_guess += "```".len() + 1;
-
-    len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();
-
-    if len_guess > byte_limit {
-        return Err(anyhow!("Current line too long to send to model."));
-    }
-
-    let mut excerpt_start_row = cursor_row;
-    let mut excerpt_end_row = cursor_row;
-    let mut no_more_before = cursor_row == 0;
-    let mut no_more_after = cursor_row >= last_buffer_row;
-    let mut row_delta = 1;
-    loop {
-        if !no_more_before {
-            let row = cursor_point.row - row_delta;
-            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
-            let mut new_len_guess = len_guess + line_len;
-            if row == 0 {
-                new_len_guess += START_OF_FILE_MARKER.len() + 1;
-            }
-            if new_len_guess <= byte_limit {
-                len_guess = new_len_guess;
-                excerpt_start_row = row;
-                if row == 0 {
-                    no_more_before = true;
-                }
-            } else {
-                no_more_before = true;
-            }
-        }
-        if excerpt_end_row - excerpt_start_row >= line_limit {
-            break;
-        }
-        if !no_more_after {
-            let row = cursor_point.row + row_delta;
-            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
-            let new_len_guess = len_guess + line_len;
-            if new_len_guess <= byte_limit {
-                len_guess = new_len_guess;
-                excerpt_end_row = row;
-                if row >= last_buffer_row {
-                    no_more_after = true;
-                }
-            } else {
-                no_more_after = true;
-            }
-        }
-        if excerpt_end_row - excerpt_start_row >= line_limit {
-            break;
-        }
-        if no_more_before && no_more_after {
-            break;
-        }
-        row_delta += 1;
-    }
-
-    let excerpt_start = Point::new(excerpt_start_row, 0);
-    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
-    Ok((
-        excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
-        len_guess,
-    ))
-}
-
 fn prompt_for_events<'a>(
     events: impl Iterator<Item = &'a Event>,
     mut bytes_remaining: usize,
@@ -1671,6 +1521,7 @@ mod tests {
     use gpui::TestAppContext;
     use http_client::FakeHttpClient;
     use indoc::indoc;
+    use language::Point;
     use language_models::RefreshLlmTokenListener;
     use rpc::proto;
     use settings::SettingsStore;