Implement simpler logic for edit predictions prompt byte limits (#23983)

Michael Sloan created

Realized that the logic in #23814 was more than needed, and harder to
maintain. Something like that could make sense if using the tokenizer
and wanting to precisely hit a token limit. However in the case of edit
predictions it's more of a latency+expense vs capability tradeoff, and
so such precision is unnecessary.

Happily this change didn't require much extra work, just copy-modifying
parts of that change was sufficient.

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs |   9 
crates/zeta/src/zeta.rs  | 369 ++++++++++++++++-------------------------
2 files changed, 147 insertions(+), 231 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -42,6 +42,11 @@ use util::ResultExt;
 
 pub use token::*;
 
+const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
+
+/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`.
+const MAX_OUTPUT_TOKENS: u32 = 2048;
+
 pub struct LlmState {
     pub config: Config,
     pub executor: Executor,
@@ -52,8 +57,6 @@ pub struct LlmState {
         RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
 }
 
-const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
-
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
         let database_url = config
@@ -488,7 +491,7 @@ async fn predict_edits(
         fireworks::CompletionRequest {
             model: model.to_string(),
             prompt: prompt.clone(),
-            max_tokens: 2048,
+            max_tokens: MAX_OUTPUT_TOKENS,
             temperature: 0.,
             prediction: Some(fireworks::Prediction::Content {
                 content: params.input_excerpt.clone(),

crates/zeta/src/zeta.rs 🔗

@@ -58,11 +58,19 @@ const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str =
 /// intentionally low to err on the side of underestimating limits.
 const BYTES_PER_TOKEN_GUESS: usize = 3;
 
-/// This is based on the output token limit `max_tokens: 2048` in `crates/collab/src/llm.rs`. Number
-/// of output tokens is relevant to the size of the input excerpt because the model is tasked with
-/// outputting a modified excerpt. `2/3` is chosen so that there are some output tokens remaining
-/// for the model to specify insertions.
-const BUFFER_EXCERPT_BYTE_LIMIT: usize = (2048 * 2 / 3) * BYTES_PER_TOKEN_GUESS;
+/// Output token limit, used to inform the size of the input. A copy of this constant is also in
+/// `crates/collab/src/llm.rs`.
+const MAX_OUTPUT_TOKENS: usize = 2048;
+
+/// Total bytes limit for editable region of buffer excerpt.
+///
+/// The number of output tokens is relevant to the size of the input excerpt because the model is
+/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
+/// remaining for the model to specify insertions.
+const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
+
+/// Total line limit for editable region of buffer excerpt.
+const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;
 
 /// Note that this is not the limit for the overall prompt, just for the inputs to the template
 /// instantiated in `crates/collab/src/llm.rs`.
@@ -342,12 +350,11 @@ impl Zeta {
         F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsParams) -> R + 'static,
         R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
     {
-        let buffer = buffer.clone();
         let snapshot = self.report_changes_for_buffer(&buffer, cx);
         let cursor_point = cursor.to_point(&snapshot);
         let cursor_offset = cursor_point.to_offset(&snapshot);
         let events = self.events.clone();
-        let path = snapshot
+        let path: Arc<Path> = snapshot
             .file()
             .map(|f| Arc::from(f.full_path(cx).as_path()))
             .unwrap_or_else(|| Arc::from(Path::new("untitled")));
@@ -356,25 +363,40 @@ impl Zeta {
         let llm_token = self.llm_token.clone();
         let is_staff = cx.is_staff();
 
+        let buffer = buffer.clone();
         cx.spawn(|_, cx| async move {
             let request_sent_at = Instant::now();
 
-            let (input_events, input_excerpt, input_outline, excerpt_range) = cx
+            let (input_events, input_excerpt, excerpt_range, input_outline) = cx
                 .background_executor()
                 .spawn({
                     let snapshot = snapshot.clone();
+                    let path = path.clone();
                     async move {
-                        let (input_excerpt, excerpt_range) =
-                            prompt_for_excerpt(&snapshot, cursor_point, cursor_offset)?;
-
-                        let chars_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
-                        let input_events = prompt_for_events(events.iter(), chars_remaining);
+                        let path = path.to_string_lossy();
+                        let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
+                            cursor_point,
+                            BUFFER_EXCERPT_BYTE_LIMIT,
+                            BUFFER_EXCERPT_LINE_LIMIT,
+                            &path,
+                            &snapshot,
+                        )?;
+                        let input_excerpt = prompt_for_excerpt(
+                            cursor_offset,
+                            &excerpt_range,
+                            excerpt_len_guess,
+                            &path,
+                            &snapshot,
+                        );
+
+                        let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
+                        let input_events = prompt_for_events(events.iter(), bytes_remaining);
 
                         // Note that input_outline is not currently used in prompt generation and so
                         // is not counted towards TOTAL_BYTE_LIMIT.
                         let input_outline = prompt_for_outline(&snapshot);
 
-                        anyhow::Ok((input_events, input_excerpt, input_outline, excerpt_range))
+                        anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
                     }
                 })
                 .await?;
@@ -998,201 +1020,137 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
     input_outline
 }
 
-#[derive(Debug, Default)]
-struct ExcerptPromptBuilder<'a> {
-    file_path: Cow<'a, str>,
-    include_start_of_file_marker: bool,
-    before_editable_region: Option<ReversedStringChunks<'a>>,
-    before_cursor: ReversedStringChunks<'a>,
-    after_cursor: StringChunks<'a>,
-    after_editable_region: Option<StringChunks<'a>>,
-}
-
-impl<'a> ExcerptPromptBuilder<'a> {
-    pub fn len(&self) -> usize {
-        let mut length = 0;
-        length += "```".len();
-        length += self.file_path.len();
-        length += 1;
-        if self.include_start_of_file_marker {
-            length += START_OF_FILE_MARKER.len();
-            length += 1;
-        }
-        if let Some(before_editable_region) = &self.before_editable_region {
-            length += before_editable_region.len();
-            length += 1;
-        }
-        length += EDITABLE_REGION_START_MARKER.len();
-        length += 1;
-        length += self.before_cursor.len();
-        length += CURSOR_MARKER.len();
-        length += self.after_cursor.len();
-        length += 1;
-        length += EDITABLE_REGION_END_MARKER.len();
-        length += 1;
-        if let Some(after_editable_region) = &self.after_editable_region {
-            length += after_editable_region.len();
-            length += 1;
-        }
-        length += "```".len();
-        length
-    }
-
-    pub fn to_string(&self) -> String {
-        let length = self.len();
-        let mut result = String::with_capacity(length);
-        result.push_str("```");
-        result.push_str(&self.file_path);
-        result.push('\n');
-        if self.include_start_of_file_marker {
-            result.push_str(START_OF_FILE_MARKER);
-            result.push('\n');
-        }
-        if let Some(before_editable_region) = &self.before_editable_region {
-            before_editable_region.add_to_string(&mut result);
-            result.push('\n');
-        }
-        result.push_str(EDITABLE_REGION_START_MARKER);
-        result.push('\n');
-        self.before_cursor.add_to_string(&mut result);
-        result.push_str(CURSOR_MARKER);
-        self.after_cursor.add_to_string(&mut result);
-        result.push('\n');
-        result.push_str(EDITABLE_REGION_END_MARKER);
-        result.push('\n');
-        if let Some(after_editable_region) = &self.after_editable_region {
-            after_editable_region.add_to_string(&mut result);
-            result.push('\n');
-        }
-        result.push_str("```");
-        debug_assert!(
-            result.len() == length,
-            "Expected length: {}, Actual length: {}",
-            length,
-            result.len()
-        );
-        result
-    }
-}
-
-#[derive(Debug, Default)]
-pub struct StringChunks<'a> {
-    chunks: Vec<&'a str>,
-    length: usize,
-}
-
-#[derive(Debug, Default)]
-pub struct ReversedStringChunks<'a>(StringChunks<'a>);
-
-impl<'a> StringChunks<'a> {
-    pub fn len(&self) -> usize {
-        self.length
-    }
+fn prompt_for_excerpt(
+    offset: usize,
+    excerpt_range: &Range<usize>,
+    mut len_guess: usize,
+    path: &str,
+    snapshot: &BufferSnapshot,
+) -> String {
+    let point_range = excerpt_range.to_point(snapshot);
+
+    // Include one line of extra context before and after editable range, if those lines are non-empty.
+    let extra_context_before_range =
+        if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
+            let range =
+                (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
+            len_guess += range.end - range.start;
+            Some(range)
+        } else {
+            None
+        };
+    let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
+        && !snapshot.is_line_blank(point_range.end.row + 1)
+    {
+        let range = (point_range.end
+            ..Point::new(
+                point_range.end.row + 1,
+                snapshot.line_len(point_range.end.row + 1),
+            ))
+            .to_offset(snapshot);
+        len_guess += range.end - range.start;
+        Some(range)
+    } else {
+        None
+    };
 
-    pub fn extend(&mut self, new_chunks: impl Iterator<Item = &'a str>) {
-        self.chunks
-            .extend(new_chunks.inspect(|chunk| self.length += chunk.len()));
-    }
+    let mut prompt_excerpt = String::with_capacity(len_guess);
+    writeln!(prompt_excerpt, "```{}", path).unwrap();
 
-    pub fn append_from_buffer<T: ToOffset>(
-        &mut self,
-        snapshot: &'a BufferSnapshot,
-        range: Range<T>,
-    ) {
-        self.extend(snapshot.text_for_range(range));
+    if excerpt_range.start == 0 {
+        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
     }
 
-    pub fn add_to_string(&self, string: &mut String) {
-        for chunk in self.chunks.iter() {
-            string.push_str(chunk);
+    if let Some(extra_context_before_range) = extra_context_before_range {
+        for chunk in snapshot.text_for_range(extra_context_before_range) {
+            prompt_excerpt.push_str(chunk);
         }
     }
-}
-
-impl<'a> ReversedStringChunks<'a> {
-    pub fn len(&self) -> usize {
-        self.0.len()
+    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
+    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
+        prompt_excerpt.push_str(chunk);
     }
-
-    pub fn prepend_from_buffer<T: ToOffset>(
-        &mut self,
-        snapshot: &'a BufferSnapshot,
-        range: Range<T>,
-    ) {
-        self.0.extend(snapshot.reversed_chunks_in_range(range));
+    prompt_excerpt.push_str(CURSOR_MARKER);
+    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
+        prompt_excerpt.push_str(chunk);
     }
+    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 
-    pub fn add_to_string(&self, string: &mut String) {
-        for chunk in self.0.chunks.iter().rev() {
-            string.push_str(chunk);
+    if let Some(extra_context_after_range) = extra_context_after_range {
+        for chunk in snapshot.text_for_range(extra_context_after_range) {
+            prompt_excerpt.push_str(chunk);
         }
     }
+
+    write!(prompt_excerpt, "\n```").unwrap();
+    debug_assert!(
+        prompt_excerpt.len() <= len_guess,
+        "Excerpt length {} exceeds estimated length {}",
+        prompt_excerpt.len(),
+        len_guess
+    );
+    prompt_excerpt
 }
 
-/// Computes a prompt for the excerpt of the buffer around the cursor. This always includes complete
-/// lines and the result length will be `<= MAX_INPUT_EXCERPT_BYTES`.
-fn prompt_for_excerpt(
-    snapshot: &BufferSnapshot,
+fn excerpt_range_for_position(
     cursor_point: Point,
-    cursor_offset: usize,
-) -> Result<(String, Range<usize>)> {
-    let mut builder = ExcerptPromptBuilder::default();
-    builder.file_path = snapshot.file().map_or(Cow::Borrowed("untitled"), |file| {
-        file.path().to_string_lossy()
-    });
-
+    byte_limit: usize,
+    line_limit: u32,
+    path: &str,
+    snapshot: &BufferSnapshot,
+) -> Result<(Range<usize>, usize)> {
     let cursor_row = cursor_point.row;
-    let cursor_line_start_offset = Point::new(cursor_row, 0).to_offset(snapshot);
-    let cursor_line_end_offset =
-        Point::new(cursor_row, snapshot.line_len(cursor_row)).to_offset(snapshot);
-    builder
-        .before_cursor
-        .prepend_from_buffer(snapshot, cursor_line_start_offset..cursor_offset);
-    builder
-        .after_cursor
-        .append_from_buffer(snapshot, cursor_offset..cursor_line_end_offset);
-
-    if builder.len() > BUFFER_EXCERPT_BYTE_LIMIT {
+    let last_buffer_row = snapshot.max_point().row;
+
+    // This is an overestimate because it includes parts of prompt_for_excerpt which are
+    // conditionally skipped.
+    let mut len_guess = 0;
+    len_guess += "```".len() + path.len() + 1;
+    len_guess += START_OF_FILE_MARKER.len() + 1;
+    len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
+    len_guess += CURSOR_MARKER.len();
+    len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
+    len_guess += "```".len() + 1;
+
+    len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();
+
+    if len_guess > byte_limit {
         return Err(anyhow!("Current line too long to send to model."));
     }
 
-    let last_buffer_row = snapshot.max_point().row;
-
-    // Figure out how many lines of the buffer to include in the prompt, walking outwards from the
-    // cursor. Even if a line before or after the cursor causes the byte limit to be exceeded,
-    // continues walking in the other direction.
-    let mut first_included_row = cursor_row;
-    let mut last_included_row = cursor_row;
+    let mut excerpt_start_row = cursor_row;
+    let mut excerpt_end_row = cursor_row;
     let mut no_more_before = cursor_row == 0;
     let mut no_more_after = cursor_row >= last_buffer_row;
-    let mut output_len = builder.len();
     let mut row_delta = 1;
     loop {
         if !no_more_before {
             let row = cursor_point.row - row_delta;
-            let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap();
-            let mut new_output_len = output_len + line_len;
+            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
+            let mut new_len_guess = len_guess + line_len;
             if row == 0 {
-                new_output_len += START_OF_FILE_MARKER.len() + 1;
+                new_len_guess += START_OF_FILE_MARKER.len() + 1;
             }
-            if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT {
-                output_len = new_output_len;
-                first_included_row = row;
+            if new_len_guess <= byte_limit {
+                len_guess = new_len_guess;
+                excerpt_start_row = row;
                 if row == 0 {
-                    builder.include_start_of_file_marker = true;
                     no_more_before = true;
                 }
             } else {
                 no_more_before = true;
             }
         }
+        if excerpt_end_row - excerpt_start_row >= line_limit {
+            break;
+        }
         if !no_more_after {
             let row = cursor_point.row + row_delta;
-            let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap();
-            let new_output_len = output_len + line_len;
-            if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT {
-                output_len = new_output_len;
-                last_included_row = row;
+            let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
+            let new_len_guess = len_guess + line_len;
+            if new_len_guess <= byte_limit {
+                len_guess = new_len_guess;
+                excerpt_end_row = row;
                 if row >= last_buffer_row {
                     no_more_after = true;
                 }
@@ -1200,66 +1158,21 @@ fn prompt_for_excerpt(
                 no_more_after = true;
             }
         }
+        if excerpt_end_row - excerpt_start_row >= line_limit {
+            break;
+        }
         if no_more_before && no_more_after {
             break;
         }
         row_delta += 1;
     }
 
-    // Include a line of context outside the editable region, but only if it is not the first line
-    // (otherwise the first line of the file would never be uneditable).
-    let first_editable_row = if first_included_row != 0
-        && first_included_row < cursor_row
-        && !snapshot.is_line_blank(first_included_row)
-    {
-        let mut before_editable_region = ReversedStringChunks::default();
-        before_editable_region.prepend_from_buffer(
-            snapshot,
-            Point::new(first_included_row, 0)
-                ..Point::new(first_included_row, snapshot.line_len(first_included_row)),
-        );
-        builder.before_editable_region = Some(before_editable_region);
-        first_included_row + 1
-    } else {
-        first_included_row
-    };
-
-    // Include a line of context outside the editable region, but only if it is not the last line
-    // (otherwise the first line of the file would never be uneditable).
-    let last_editable_row = if last_included_row < last_buffer_row
-        && last_included_row > cursor_row
-        && !snapshot.is_line_blank(last_included_row)
-    {
-        let mut after_editable_region = StringChunks::default();
-        after_editable_region.append_from_buffer(
-            snapshot,
-            Point::new(last_included_row, 0)
-                ..Point::new(last_included_row, snapshot.line_len(last_included_row)),
-        );
-        builder.after_editable_region = Some(after_editable_region);
-        last_included_row + 1
-    } else {
-        last_included_row
-    };
-
-    let editable_range = (Point::new(first_editable_row, 0)
-        ..Point::new(last_editable_row, snapshot.line_len(last_editable_row)))
-        .to_offset(snapshot);
-
-    let before_cursor_row = editable_range.start..cursor_line_start_offset;
-    let after_cursor_row = cursor_line_end_offset..editable_range.end;
-    if !before_cursor_row.is_empty() {
-        builder
-            .before_cursor
-            .prepend_from_buffer(snapshot, before_cursor_row);
-    }
-    if !after_cursor_row.is_empty() {
-        builder
-            .after_cursor
-            .append_from_buffer(snapshot, after_cursor_row);
-    }
-
-    anyhow::Ok((builder.to_string(), editable_range))
+    let excerpt_start = Point::new(excerpt_start_row, 0);
+    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
+    Ok((
+        excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
+        len_guess,
+    ))
 }
 
 fn prompt_for_events<'a>(