From ade3e45a36f829821db5d933de105f2c9d0e0485 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Wed, 29 Jan 2025 14:56:29 -0700 Subject: [PATCH] Add character limits to edit prediction prompt generation (#23814) Limits the size of the buffer excerpt and the size of change history. Release Notes: - N/A --------- Co-authored-by: Richard Co-authored-by: Joao --- crates/zeta/src/zeta.rs | 383 ++++++++++++++++++++++++++++++++-------- 1 file changed, 308 insertions(+), 75 deletions(-) diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 2f297d97e9a2d405bc5e40431c2f6cc576a91686..0eefc78f6600e920a3c73fa360cbaad7f528e854 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -41,6 +41,30 @@ const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>"; const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>"; const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); +// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants? + +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +const BYTES_PER_TOKEN_GUESS: usize = 3; + +/// This is based on the output token limit `max_tokens: 2048` in `crates/collab/src/llm.rs`. Number +/// of output tokens is relevant to the size of the input excerpt because the model is tasked with +/// outputting a modified excerpt. `2/3` is chosen so that there are some output tokens remaining +/// for the model to specify insertions. +const BUFFER_EXCERPT_BYTE_LIMIT: usize = (2048 * 2 / 3) * BYTES_PER_TOKEN_GUESS; + +/// Note that this is not the limit for the overall prompt, just for the inputs to the template +/// instantiated in `crates/collab/src/llm.rs`. +const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2; + +/// Maximum number of events to include in the prompt. +const MAX_EVENT_COUNT: usize = 16; + +/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of +/// equally splitting up the the remaining bytes after the largest possible buffer excerpt. +const PER_EVENT_BYTE_LIMIT: usize = + (TOTAL_BYTE_LIMIT - BUFFER_EXCERPT_BYTE_LIMIT) / MAX_EVENT_COUNT * 4; + actions!(edit_prediction, [ClearHistory]); #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] @@ -223,8 +247,6 @@ impl Zeta { } fn push_event(&mut self, event: Event) { - const MAX_EVENT_COUNT: usize = 16; - if let Some(Event::BufferChange { new_snapshot: last_new_snapshot, timestamp: last_timestamp, @@ -294,7 +316,7 @@ impl Zeta { pub fn request_completion_impl( &mut self, buffer: &Entity, - position: language::Anchor, + cursor: language::Anchor, cx: &mut Context, perform_predict_edits: F, ) -> Task>> @@ -303,9 +325,8 @@ impl Zeta { R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(buffer, cx); - let point = position.to_point(&snapshot); - let offset = point.to_offset(&snapshot); - let excerpt_range = excerpt_range_for_position(point, &snapshot); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); let events = self.events.clone(); let path = snapshot .file() @@ -319,28 +340,25 @@ impl Zeta { cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); - let (input_events, input_excerpt, input_outline) = cx + let (input_events, input_excerpt, input_outline, excerpt_range) = cx .background_executor() .spawn({ let snapshot = snapshot.clone(); - let excerpt_range = excerpt_range.clone(); async move { - let mut input_events = String::new(); - for event in events { - if !input_events.is_empty() { - input_events.push('\n'); - input_events.push('\n'); - } - input_events.push_str(&event.to_prompt()); - } + let (input_excerpt, excerpt_range) = + prompt_for_excerpt(&snapshot, cursor_point, cursor_offset)?; + + let chars_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len()); + let input_events = prompt_for_events(events.iter(), chars_remaining); - let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset); + // Note that input_outline is not currently used in prompt generation and so + // is not counted towards TOTAL_BYTE_LIMIT. let input_outline = prompt_for_outline(&snapshot); - (input_events, input_excerpt, input_outline) + anyhow::Ok((input_events, input_excerpt, input_outline, excerpt_range)) } }) - .await; + .await?; log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt); @@ -359,7 +377,7 @@ impl Zeta { output_excerpt, &snapshot, excerpt_range, - offset, + cursor_offset, path, input_outline, input_events, @@ -814,77 +832,292 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { input_outline } -fn prompt_for_excerpt( - snapshot: &BufferSnapshot, - excerpt_range: &Range, - offset: usize, -) -> String { - let mut prompt_excerpt = String::new(); - writeln!( - prompt_excerpt, - "```{}", - snapshot - .file() - .map_or(Cow::Borrowed("untitled"), |file| file - .path() - .to_string_lossy()) - ) - .unwrap(); +#[derive(Debug, Default)] +struct ExcerptPromptBuilder<'a> { + file_path: Cow<'a, str>, + include_start_of_file_marker: bool, + before_editable_region: Option>, + before_cursor: ReversedStringChunks<'a>, + after_cursor: StringChunks<'a>, + after_editable_region: Option>, +} - if excerpt_range.start == 0 { - writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap(); +impl<'a> ExcerptPromptBuilder<'a> { + pub fn len(&self) -> usize { + let mut length = 0; + length += "```".len(); + length += self.file_path.len(); + length += 1; + if self.include_start_of_file_marker { + length += START_OF_FILE_MARKER.len(); + length += 1; + } + if let Some(before_editable_region) = &self.before_editable_region { + length += before_editable_region.len(); + length += 1; + } + length += EDITABLE_REGION_START_MARKER.len(); + length += 1; + length += self.before_cursor.len(); + length += CURSOR_MARKER.len(); + length += self.after_cursor.len(); + length += 1; + length += EDITABLE_REGION_END_MARKER.len(); + length += 1; + if let Some(after_editable_region) = &self.after_editable_region { + length += after_editable_region.len(); + length += 1; + } + length += "```".len(); + length } - let point_range = excerpt_range.to_point(snapshot); - if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) { - let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start; - for chunk in snapshot.text_for_range(extra_context_line_range) { - prompt_excerpt.push_str(chunk); + pub fn to_string(&self) -> String { + let length = self.len(); + let mut result = String::with_capacity(length); + result.push_str("```"); + result.push_str(&self.file_path); + result.push('\n'); + if self.include_start_of_file_marker { + result.push_str(START_OF_FILE_MARKER); + result.push('\n'); + } + if let Some(before_editable_region) = &self.before_editable_region { + before_editable_region.add_to_string(&mut result); + result.push('\n'); } + result.push_str(EDITABLE_REGION_START_MARKER); + result.push('\n'); + self.before_cursor.add_to_string(&mut result); + result.push_str(CURSOR_MARKER); + self.after_cursor.add_to_string(&mut result); + result.push('\n'); + result.push_str(EDITABLE_REGION_END_MARKER); + result.push('\n'); + if let Some(after_editable_region) = &self.after_editable_region { + after_editable_region.add_to_string(&mut result); + result.push('\n'); + } + result.push_str("```"); + debug_assert!( + result.len() == length, + "Expected length: {}, Actual length: {}", + length, + result.len() + ); + result } - writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap(); - for chunk in snapshot.text_for_range(excerpt_range.start..offset) { - prompt_excerpt.push_str(chunk); +} + +#[derive(Debug, Default)] +pub struct StringChunks<'a> { + chunks: Vec<&'a str>, + length: usize, +} + +#[derive(Debug, Default)] +pub struct ReversedStringChunks<'a>(StringChunks<'a>); + +impl<'a> StringChunks<'a> { + pub fn len(&self) -> usize { + self.length } - prompt_excerpt.push_str(CURSOR_MARKER); - for chunk in snapshot.text_for_range(offset..excerpt_range.end) { - prompt_excerpt.push_str(chunk); + + pub fn extend(&mut self, new_chunks: impl Iterator) { + self.chunks + .extend(new_chunks.inspect(|chunk| self.length += chunk.len())); } - write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); - if point_range.end.row < snapshot.max_point().row - && !snapshot.is_line_blank(point_range.end.row + 1) - { - let extra_context_line_range = point_range.end - ..Point::new( - point_range.end.row + 1, - snapshot.line_len(point_range.end.row + 1), - ); - for chunk in snapshot.text_for_range(extra_context_line_range) { - prompt_excerpt.push_str(chunk); + pub fn append_from_buffer( + &mut self, + snapshot: &'a BufferSnapshot, + range: Range, + ) { + self.extend(snapshot.text_for_range(range)); + } + + pub fn add_to_string(&self, string: &mut String) { + for chunk in self.chunks.iter() { + string.push_str(chunk); } } +} + +impl<'a> ReversedStringChunks<'a> { + pub fn len(&self) -> usize { + self.0.len() + } - write!(prompt_excerpt, "\n```").unwrap(); - prompt_excerpt + pub fn prepend_from_buffer( + &mut self, + snapshot: &'a BufferSnapshot, + range: Range, + ) { + self.0.extend(snapshot.reversed_chunks_in_range(range)); + } + + pub fn add_to_string(&self, string: &mut String) { + for chunk in self.0.chunks.iter().rev() { + string.push_str(chunk); + } + } } -fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range { - const CONTEXT_LINES: u32 = 32; +/// Computes a prompt for the excerpt of the buffer around the cursor. This always includes complete +/// lines and the result length will be `<= MAX_INPUT_EXCERPT_BYTES`. +fn prompt_for_excerpt( + snapshot: &BufferSnapshot, + cursor_point: Point, + cursor_offset: usize, +) -> Result<(String, Range)> { + let mut builder = ExcerptPromptBuilder::default(); + builder.file_path = snapshot.file().map_or(Cow::Borrowed("untitled"), |file| { + file.path().to_string_lossy() + }); + + let cursor_row = cursor_point.row; + let cursor_line_start_offset = Point::new(cursor_row, 0).to_offset(snapshot); + let cursor_line_end_offset = + Point::new(cursor_row, snapshot.line_len(cursor_row)).to_offset(snapshot); + builder + .before_cursor + .prepend_from_buffer(snapshot, cursor_line_start_offset..cursor_offset); + builder + .after_cursor + .append_from_buffer(snapshot, cursor_offset..cursor_line_end_offset); + + if builder.len() > BUFFER_EXCERPT_BYTE_LIMIT { + return Err(anyhow!("Current line too long to send to model.")); + } - let mut context_lines_before = CONTEXT_LINES; - let mut context_lines_after = CONTEXT_LINES; - if point.row < CONTEXT_LINES { - context_lines_after += CONTEXT_LINES - point.row; - } else if point.row + CONTEXT_LINES > snapshot.max_point().row { - context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row; + let last_buffer_row = snapshot.max_point().row; + + // Figure out how many lines of the buffer to include in the prompt, walking outwards from the + // cursor. Even if a line before or after the cursor causes the byte limit to be exceeded, + // continues walking in the other direction. + let mut first_included_row = cursor_row; + let mut last_included_row = cursor_row; + let mut no_more_before = cursor_row == 0; + let mut no_more_after = cursor_row >= last_buffer_row; + let mut output_len = builder.len(); + let mut row_delta = 1; + loop { + if !no_more_before { + let row = cursor_point.row - row_delta; + let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap(); + let mut new_output_len = output_len + line_len; + if row == 0 { + new_output_len += START_OF_FILE_MARKER.len() + 1; + } + if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT { + output_len = new_output_len; + first_included_row = row; + if row == 0 { + builder.include_start_of_file_marker = true; + no_more_before = true; + } + } else { + no_more_before = true; + } + } + if !no_more_after { + let row = cursor_point.row + row_delta; + let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap(); + let new_output_len = output_len + line_len; + if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT { + output_len = new_output_len; + last_included_row = row; + if row >= last_buffer_row { + no_more_after = true; + } + } else { + no_more_after = true; + } + } + if no_more_before && no_more_after { + break; + } + row_delta += 1; } - let excerpt_start_row = point.row.saturating_sub(context_lines_before); - let excerpt_start = Point::new(excerpt_start_row, 0); - let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row); - let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row)); - excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot) + // Include a line of context outside the editable region, but only if it is not the first line + // (otherwise the first line of the file would never be uneditable). + let first_editable_row = if first_included_row != 0 + && first_included_row < cursor_row + && !snapshot.is_line_blank(first_included_row) + { + let mut before_editable_region = ReversedStringChunks::default(); + before_editable_region.prepend_from_buffer( + snapshot, + Point::new(first_included_row, 0) + ..Point::new(first_included_row, snapshot.line_len(first_included_row)), + ); + builder.before_editable_region = Some(before_editable_region); + first_included_row + 1 + } else { + first_included_row + }; + + // Include a line of context outside the editable region, but only if it is not the last line + // (otherwise the first line of the file would never be uneditable). + let last_editable_row = if last_included_row < last_buffer_row + && last_included_row > cursor_row + && !snapshot.is_line_blank(last_included_row) + { + let mut after_editable_region = StringChunks::default(); + after_editable_region.append_from_buffer( + snapshot, + Point::new(last_included_row, 0) + ..Point::new(last_included_row, snapshot.line_len(last_included_row)), + ); + builder.after_editable_region = Some(after_editable_region); + last_included_row + 1 + } else { + last_included_row + }; + + let editable_range = (Point::new(first_editable_row, 0) + ..Point::new(last_editable_row, snapshot.line_len(last_editable_row))) + .to_offset(snapshot); + + let before_cursor_row = editable_range.start..cursor_line_start_offset; + let after_cursor_row = cursor_line_end_offset..editable_range.end; + if !before_cursor_row.is_empty() { + builder + .before_cursor + .prepend_from_buffer(snapshot, before_cursor_row); + } + if !after_cursor_row.is_empty() { + builder + .after_cursor + .append_from_buffer(snapshot, after_cursor_row); + } + + anyhow::Ok((builder.to_string(), editable_range)) +} + +fn prompt_for_events<'a>( + events: impl Iterator, + mut bytes_remaining: usize, +) -> String { + let mut result = String::new(); + for event in events { + if !result.is_empty() { + result.push('\n'); + result.push('\n'); + } + let event_string = event.to_prompt(); + let len = event_string.len(); + if len > PER_EVENT_BYTE_LIMIT { + continue; + } + if len > bytes_remaining { + break; + } + bytes_remaining -= len; + result.push_str(&event_string); + } + result } struct RegisteredBuffer {