Add character limits to edit prediction prompt generation (#23814)

Michael Sloan , Richard , and Joao created

Limits the size of the buffer excerpt and the size of change history.

Release Notes:

- N/A

---------

Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Joao <joao@zed.dev>

Change summary

crates/zeta/src/zeta.rs | 383 ++++++++++++++++++++++++++++++++++--------
1 file changed, 308 insertions(+), 75 deletions(-)

Detailed changes

crates/zeta/src/zeta.rs 🔗

@@ -41,6 +41,30 @@ const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
 const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
 const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
 
+// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?
+
+/// Typical number of string bytes per token for the purposes of limiting model input. This is
+/// intentionally low to err on the side of underestimating limits.
+const BYTES_PER_TOKEN_GUESS: usize = 3;
+
+/// This is based on the output token limit `max_tokens: 2048` in `crates/collab/src/llm.rs`. Number
+/// of output tokens is relevant to the size of the input excerpt because the model is tasked with
+/// outputting a modified excerpt. `2/3` is chosen so that there are some output tokens remaining
+/// for the model to specify insertions.
+const BUFFER_EXCERPT_BYTE_LIMIT: usize = (2048 * 2 / 3) * BYTES_PER_TOKEN_GUESS;
+
+/// Note that this is not the limit for the overall prompt, just for the inputs to the template
+/// instantiated in `crates/collab/src/llm.rs`.
+const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;
+
+/// Maximum number of events to include in the prompt.
+const MAX_EVENT_COUNT: usize = 16;
+
+/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of
+/// equally splitting up the the remaining bytes after the largest possible buffer excerpt.
+const PER_EVENT_BYTE_LIMIT: usize =
+    (TOTAL_BYTE_LIMIT - BUFFER_EXCERPT_BYTE_LIMIT) / MAX_EVENT_COUNT * 4;
+
 actions!(edit_prediction, [ClearHistory]);
 
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
@@ -223,8 +247,6 @@ impl Zeta {
     }
 
     fn push_event(&mut self, event: Event) {
-        const MAX_EVENT_COUNT: usize = 16;
-
         if let Some(Event::BufferChange {
             new_snapshot: last_new_snapshot,
             timestamp: last_timestamp,
@@ -294,7 +316,7 @@ impl Zeta {
     pub fn request_completion_impl<F, R>(
         &mut self,
         buffer: &Entity<Buffer>,
-        position: language::Anchor,
+        cursor: language::Anchor,
         cx: &mut Context<Self>,
         perform_predict_edits: F,
     ) -> Task<Result<Option<InlineCompletion>>>
@@ -303,9 +325,8 @@ impl Zeta {
         R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
     {
         let snapshot = self.report_changes_for_buffer(buffer, cx);
-        let point = position.to_point(&snapshot);
-        let offset = point.to_offset(&snapshot);
-        let excerpt_range = excerpt_range_for_position(point, &snapshot);
+        let cursor_point = cursor.to_point(&snapshot);
+        let cursor_offset = cursor_point.to_offset(&snapshot);
         let events = self.events.clone();
         let path = snapshot
             .file()
@@ -319,28 +340,25 @@ impl Zeta {
         cx.spawn(|_, cx| async move {
             let request_sent_at = Instant::now();
 
-            let (input_events, input_excerpt, input_outline) = cx
+            let (input_events, input_excerpt, input_outline, excerpt_range) = cx
                 .background_executor()
                 .spawn({
                     let snapshot = snapshot.clone();
-                    let excerpt_range = excerpt_range.clone();
                     async move {
-                        let mut input_events = String::new();
-                        for event in events {
-                            if !input_events.is_empty() {
-                                input_events.push('\n');
-                                input_events.push('\n');
-                            }
-                            input_events.push_str(&event.to_prompt());
-                        }
+                        let (input_excerpt, excerpt_range) =
+                            prompt_for_excerpt(&snapshot, cursor_point, cursor_offset)?;
+
+                        let chars_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
+                        let input_events = prompt_for_events(events.iter(), chars_remaining);
 
-                        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
+                        // Note that input_outline is not currently used in prompt generation and so
+                        // is not counted towards TOTAL_BYTE_LIMIT.
                         let input_outline = prompt_for_outline(&snapshot);
 
-                        (input_events, input_excerpt, input_outline)
+                        anyhow::Ok((input_events, input_excerpt, input_outline, excerpt_range))
                     }
                 })
-                .await;
+                .await?;
 
             log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 
@@ -359,7 +377,7 @@ impl Zeta {
                 output_excerpt,
                 &snapshot,
                 excerpt_range,
-                offset,
+                cursor_offset,
                 path,
                 input_outline,
                 input_events,
@@ -814,77 +832,292 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
     input_outline
 }
 
-fn prompt_for_excerpt(
-    snapshot: &BufferSnapshot,
-    excerpt_range: &Range<usize>,
-    offset: usize,
-) -> String {
-    let mut prompt_excerpt = String::new();
-    writeln!(
-        prompt_excerpt,
-        "```{}",
-        snapshot
-            .file()
-            .map_or(Cow::Borrowed("untitled"), |file| file
-                .path()
-                .to_string_lossy())
-    )
-    .unwrap();
+#[derive(Debug, Default)]
+struct ExcerptPromptBuilder<'a> {
+    file_path: Cow<'a, str>,
+    include_start_of_file_marker: bool,
+    before_editable_region: Option<ReversedStringChunks<'a>>,
+    before_cursor: ReversedStringChunks<'a>,
+    after_cursor: StringChunks<'a>,
+    after_editable_region: Option<StringChunks<'a>>,
+}
 
-    if excerpt_range.start == 0 {
-        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
+impl<'a> ExcerptPromptBuilder<'a> {
+    pub fn len(&self) -> usize {
+        let mut length = 0;
+        length += "```".len();
+        length += self.file_path.len();
+        length += 1;
+        if self.include_start_of_file_marker {
+            length += START_OF_FILE_MARKER.len();
+            length += 1;
+        }
+        if let Some(before_editable_region) = &self.before_editable_region {
+            length += before_editable_region.len();
+            length += 1;
+        }
+        length += EDITABLE_REGION_START_MARKER.len();
+        length += 1;
+        length += self.before_cursor.len();
+        length += CURSOR_MARKER.len();
+        length += self.after_cursor.len();
+        length += 1;
+        length += EDITABLE_REGION_END_MARKER.len();
+        length += 1;
+        if let Some(after_editable_region) = &self.after_editable_region {
+            length += after_editable_region.len();
+            length += 1;
+        }
+        length += "```".len();
+        length
     }
 
-    let point_range = excerpt_range.to_point(snapshot);
-    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
-        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
-        for chunk in snapshot.text_for_range(extra_context_line_range) {
-            prompt_excerpt.push_str(chunk);
+    pub fn to_string(&self) -> String {
+        let length = self.len();
+        let mut result = String::with_capacity(length);
+        result.push_str("```");
+        result.push_str(&self.file_path);
+        result.push('\n');
+        if self.include_start_of_file_marker {
+            result.push_str(START_OF_FILE_MARKER);
+            result.push('\n');
+        }
+        if let Some(before_editable_region) = &self.before_editable_region {
+            before_editable_region.add_to_string(&mut result);
+            result.push('\n');
         }
+        result.push_str(EDITABLE_REGION_START_MARKER);
+        result.push('\n');
+        self.before_cursor.add_to_string(&mut result);
+        result.push_str(CURSOR_MARKER);
+        self.after_cursor.add_to_string(&mut result);
+        result.push('\n');
+        result.push_str(EDITABLE_REGION_END_MARKER);
+        result.push('\n');
+        if let Some(after_editable_region) = &self.after_editable_region {
+            after_editable_region.add_to_string(&mut result);
+            result.push('\n');
+        }
+        result.push_str("```");
+        debug_assert!(
+            result.len() == length,
+            "Expected length: {}, Actual length: {}",
+            length,
+            result.len()
+        );
+        result
     }
-    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
-    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
-        prompt_excerpt.push_str(chunk);
+}
+
+#[derive(Debug, Default)]
+pub struct StringChunks<'a> {
+    chunks: Vec<&'a str>,
+    length: usize,
+}
+
+#[derive(Debug, Default)]
+pub struct ReversedStringChunks<'a>(StringChunks<'a>);
+
+impl<'a> StringChunks<'a> {
+    pub fn len(&self) -> usize {
+        self.length
     }
-    prompt_excerpt.push_str(CURSOR_MARKER);
-    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
-        prompt_excerpt.push_str(chunk);
+
+    pub fn extend(&mut self, new_chunks: impl Iterator<Item = &'a str>) {
+        self.chunks
+            .extend(new_chunks.inspect(|chunk| self.length += chunk.len()));
     }
-    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 
-    if point_range.end.row < snapshot.max_point().row
-        && !snapshot.is_line_blank(point_range.end.row + 1)
-    {
-        let extra_context_line_range = point_range.end
-            ..Point::new(
-                point_range.end.row + 1,
-                snapshot.line_len(point_range.end.row + 1),
-            );
-        for chunk in snapshot.text_for_range(extra_context_line_range) {
-            prompt_excerpt.push_str(chunk);
+    pub fn append_from_buffer<T: ToOffset>(
+        &mut self,
+        snapshot: &'a BufferSnapshot,
+        range: Range<T>,
+    ) {
+        self.extend(snapshot.text_for_range(range));
+    }
+
+    pub fn add_to_string(&self, string: &mut String) {
+        for chunk in self.chunks.iter() {
+            string.push_str(chunk);
         }
     }
+}
+
+impl<'a> ReversedStringChunks<'a> {
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
 
-    write!(prompt_excerpt, "\n```").unwrap();
-    prompt_excerpt
+    pub fn prepend_from_buffer<T: ToOffset>(
+        &mut self,
+        snapshot: &'a BufferSnapshot,
+        range: Range<T>,
+    ) {
+        self.0.extend(snapshot.reversed_chunks_in_range(range));
+    }
+
+    pub fn add_to_string(&self, string: &mut String) {
+        for chunk in self.0.chunks.iter().rev() {
+            string.push_str(chunk);
+        }
+    }
 }
 
-fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
-    const CONTEXT_LINES: u32 = 32;
+/// Computes a prompt for the excerpt of the buffer around the cursor. This always includes complete
+/// lines and the result length will be `<= MAX_INPUT_EXCERPT_BYTES`.
+fn prompt_for_excerpt(
+    snapshot: &BufferSnapshot,
+    cursor_point: Point,
+    cursor_offset: usize,
+) -> Result<(String, Range<usize>)> {
+    let mut builder = ExcerptPromptBuilder::default();
+    builder.file_path = snapshot.file().map_or(Cow::Borrowed("untitled"), |file| {
+        file.path().to_string_lossy()
+    });
+
+    let cursor_row = cursor_point.row;
+    let cursor_line_start_offset = Point::new(cursor_row, 0).to_offset(snapshot);
+    let cursor_line_end_offset =
+        Point::new(cursor_row, snapshot.line_len(cursor_row)).to_offset(snapshot);
+    builder
+        .before_cursor
+        .prepend_from_buffer(snapshot, cursor_line_start_offset..cursor_offset);
+    builder
+        .after_cursor
+        .append_from_buffer(snapshot, cursor_offset..cursor_line_end_offset);
+
+    if builder.len() > BUFFER_EXCERPT_BYTE_LIMIT {
+        return Err(anyhow!("Current line too long to send to model."));
+    }
 
-    let mut context_lines_before = CONTEXT_LINES;
-    let mut context_lines_after = CONTEXT_LINES;
-    if point.row < CONTEXT_LINES {
-        context_lines_after += CONTEXT_LINES - point.row;
-    } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
-        context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
+    let last_buffer_row = snapshot.max_point().row;
+
+    // Figure out how many lines of the buffer to include in the prompt, walking outwards from the
+    // cursor. Even if a line before or after the cursor causes the byte limit to be exceeded,
+    // continues walking in the other direction.
+    let mut first_included_row = cursor_row;
+    let mut last_included_row = cursor_row;
+    let mut no_more_before = cursor_row == 0;
+    let mut no_more_after = cursor_row >= last_buffer_row;
+    let mut output_len = builder.len();
+    let mut row_delta = 1;
+    loop {
+        if !no_more_before {
+            let row = cursor_point.row - row_delta;
+            let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap();
+            let mut new_output_len = output_len + line_len;
+            if row == 0 {
+                new_output_len += START_OF_FILE_MARKER.len() + 1;
+            }
+            if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT {
+                output_len = new_output_len;
+                first_included_row = row;
+                if row == 0 {
+                    builder.include_start_of_file_marker = true;
+                    no_more_before = true;
+                }
+            } else {
+                no_more_before = true;
+            }
+        }
+        if !no_more_after {
+            let row = cursor_point.row + row_delta;
+            let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap();
+            let new_output_len = output_len + line_len;
+            if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT {
+                output_len = new_output_len;
+                last_included_row = row;
+                if row >= last_buffer_row {
+                    no_more_after = true;
+                }
+            } else {
+                no_more_after = true;
+            }
+        }
+        if no_more_before && no_more_after {
+            break;
+        }
+        row_delta += 1;
     }
 
-    let excerpt_start_row = point.row.saturating_sub(context_lines_before);
-    let excerpt_start = Point::new(excerpt_start_row, 0);
-    let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
-    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
-    excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
+    // Include a line of context outside the editable region, but only if it is not the first line
+    // (otherwise the first line of the file would never be uneditable).
+    let first_editable_row = if first_included_row != 0
+        && first_included_row < cursor_row
+        && !snapshot.is_line_blank(first_included_row)
+    {
+        let mut before_editable_region = ReversedStringChunks::default();
+        before_editable_region.prepend_from_buffer(
+            snapshot,
+            Point::new(first_included_row, 0)
+                ..Point::new(first_included_row, snapshot.line_len(first_included_row)),
+        );
+        builder.before_editable_region = Some(before_editable_region);
+        first_included_row + 1
+    } else {
+        first_included_row
+    };
+
+    // Include a line of context outside the editable region, but only if it is not the last line
+    // (otherwise the first line of the file would never be uneditable).
+    let last_editable_row = if last_included_row < last_buffer_row
+        && last_included_row > cursor_row
+        && !snapshot.is_line_blank(last_included_row)
+    {
+        let mut after_editable_region = StringChunks::default();
+        after_editable_region.append_from_buffer(
+            snapshot,
+            Point::new(last_included_row, 0)
+                ..Point::new(last_included_row, snapshot.line_len(last_included_row)),
+        );
+        builder.after_editable_region = Some(after_editable_region);
+        last_included_row + 1
+    } else {
+        last_included_row
+    };
+
+    let editable_range = (Point::new(first_editable_row, 0)
+        ..Point::new(last_editable_row, snapshot.line_len(last_editable_row)))
+        .to_offset(snapshot);
+
+    let before_cursor_row = editable_range.start..cursor_line_start_offset;
+    let after_cursor_row = cursor_line_end_offset..editable_range.end;
+    if !before_cursor_row.is_empty() {
+        builder
+            .before_cursor
+            .prepend_from_buffer(snapshot, before_cursor_row);
+    }
+    if !after_cursor_row.is_empty() {
+        builder
+            .after_cursor
+            .append_from_buffer(snapshot, after_cursor_row);
+    }
+
+    anyhow::Ok((builder.to_string(), editable_range))
+}
+
+fn prompt_for_events<'a>(
+    events: impl Iterator<Item = &'a Event>,
+    mut bytes_remaining: usize,
+) -> String {
+    let mut result = String::new();
+    for event in events {
+        if !result.is_empty() {
+            result.push('\n');
+            result.push('\n');
+        }
+        let event_string = event.to_prompt();
+        let len = event_string.len();
+        if len > PER_EVENT_BYTE_LIMIT {
+            continue;
+        }
+        if len > bytes_remaining {
+            break;
+        }
+        bytes_remaining -= len;
+        result.push_str(&event_string);
+    }
+    result
 }
 
 struct RegisteredBuffer {