@@ -0,0 +1,238 @@
+use crate::{
+ BYTES_PER_TOKEN_GUESS, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
+ START_OF_FILE_MARKER,
+};
+use language::{BufferSnapshot, Point};
+use std::{fmt::Write, ops::Range};
+
+/// Result of [`excerpt_for_cursor_position`]: the text sent to the model plus the buffer range
+/// that the model is allowed to rewrite.
+pub struct InputExcerpt {
+ /// Buffer range (in points) that was wrapped in the editable-region markers.
+ pub editable_range: Range<Point>,
+ /// Full excerpt: a code fence headed by the path, an optional start-of-file marker, leading
+ /// context, the marked editable region (with the cursor marker), and trailing context.
+ pub prompt: String,
+ /// The marked editable region on its own, produced by the same routine that writes it
+ /// into `prompt`.
+ pub speculated_output: String,
+}
+
+/// Builds the model input excerpt surrounding `position` in `snapshot`.
+///
+/// The editable region starts as the empty range at the cursor and grows to each enclosing
+/// syntax node whose estimated token count is within `editable_region_token_limit`. Whatever
+/// budget the last accepted node left over is then spent on whole-line expansion. The context
+/// region is the editable region expanded line by line with a budget of `context_token_limit`.
+///
+/// The returned `prompt` is a code fence headed by `path`, containing the context with the
+/// editable region delimited by markers; `speculated_output` holds just the marked editable
+/// region.
+pub fn excerpt_for_cursor_position(
+    position: Point,
+    path: &str,
+    snapshot: &BufferSnapshot,
+    editable_region_token_limit: usize,
+    context_token_limit: usize,
+) -> InputExcerpt {
+    let mut scope_range = position..position;
+    let mut leftover_edit_tokens = editable_region_token_limit;
+
+    // Climb the syntax tree until the next ancestor would no longer fit the editable budget.
+    while let Some(ancestor) = snapshot.syntax_ancestor(scope_range.clone()) {
+        let ancestor_tokens = tokens_for_bytes(ancestor.byte_range().len());
+        if ancestor_tokens > editable_region_token_limit {
+            break;
+        }
+
+        let (node_start, node_end) = (ancestor.start_position(), ancestor.end_position());
+        scope_range = Point::new(node_start.row as u32, node_start.column as u32)
+            ..Point::new(node_end.row as u32, node_end.column as u32);
+        leftover_edit_tokens = editable_region_token_limit - ancestor_tokens;
+    }
+
+    let editable_range = expand_range(snapshot, scope_range, leftover_edit_tokens);
+    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
+
+    // Opening fence line: "```<path>\n".
+    let mut prompt = String::new();
+    prompt.push_str("```");
+    prompt.push_str(path);
+    prompt.push('\n');
+    if context_range.start == Point::zero() {
+        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
+    }
+
+    // Context preceding the editable region.
+    let leading_context = context_range.start..editable_range.start;
+    prompt.extend(snapshot.chunks(leading_context, false).map(|chunk| chunk.text));
+
+    // The marked editable region goes into the prompt and, separately, into the speculated output.
+    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
+    let mut speculated_output = String::new();
+    push_editable_range(
+        position,
+        snapshot,
+        editable_range.clone(),
+        &mut speculated_output,
+    );
+
+    // Context following the editable region, then the closing fence.
+    let trailing_context = editable_range.end..context_range.end;
+    prompt.extend(snapshot.chunks(trailing_context, false).map(|chunk| chunk.text));
+    prompt.push_str("\n```");
+
+    InputExcerpt {
+        editable_range,
+        prompt,
+        speculated_output,
+    }
+}
+
+/// Appends the editable region to `prompt`, wrapped in the editable-region start/end markers,
+/// with the cursor marker inserted at `cursor_position`.
+///
+/// The start marker is followed by a newline; the end marker is preceded by one and has no
+/// trailing newline.
+fn push_editable_range(
+    cursor_position: Point,
+    snapshot: &BufferSnapshot,
+    editable_range: Range<Point>,
+    prompt: &mut String,
+) {
+    let before_cursor = editable_range.start..cursor_position;
+    let after_cursor = cursor_position..editable_range.end;
+
+    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
+    prompt.extend(snapshot.chunks(before_cursor, false).map(|chunk| chunk.text));
+    prompt.push_str(CURSOR_MARKER);
+    prompt.extend(snapshot.chunks(after_cursor, false).map(|chunk| chunk.text));
+    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
+}
+
+/// Widens `range` to whole lines, then keeps adding one line above and one line below per
+/// iteration while `remaining_tokens` is non-zero and the buffer has more rows in that direction.
+///
+/// Each added line deducts its estimated token count (saturating at zero), so the line that
+/// exhausts the budget is still included. The initial snap to line boundaries is not charged
+/// against the budget.
+fn expand_range(
+    snapshot: &BufferSnapshot,
+    range: Range<Point>,
+    mut remaining_tokens: usize,
+) -> Range<Point> {
+    let last_row = snapshot.max_point().row;
+    let mut start = range.start;
+    let mut end = range.end;
+    start.column = 0;
+    end.column = snapshot.line_len(end.row);
+
+    loop {
+        // Try to take one more line above.
+        let grew_up = remaining_tokens > 0 && start.row > 0;
+        if grew_up {
+            start.row -= 1;
+            let cost = tokens_for_bytes(snapshot.line_len(start.row) as usize);
+            remaining_tokens = remaining_tokens.saturating_sub(cost);
+        }
+
+        // Try to take one more line below, using whatever budget is left after the step above.
+        let grew_down = remaining_tokens > 0 && end.row < last_row;
+        if grew_down {
+            end.row += 1;
+            end.column = snapshot.line_len(end.row);
+            let cost = tokens_for_bytes(end.column as usize);
+            remaining_tokens = remaining_tokens.saturating_sub(cost);
+        }
+
+        if !(grew_up || grew_down) {
+            return start..end;
+        }
+    }
+}
+
+/// Estimates how many tokens `bytes` bytes of text occupy by dividing by
+/// `BYTES_PER_TOKEN_GUESS`.
+///
+/// Integer division rounds down, so any input shorter than `BYTES_PER_TOKEN_GUESS` bytes is
+/// estimated at zero tokens.
+// NOTE(review): callers pass `line_len` values, which presumably exclude the newline, so blank
+// and very short lines appear to be free during line expansion — confirm this is intended.
+fn tokens_for_bytes(bytes: usize) -> usize {
+ bytes / BYTES_PER_TOKEN_GUESS
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::{App, AppContext};
+ use indoc::indoc;
+ use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
+ use std::sync::Arc;
+
+ // Checks the exact prompt produced by `excerpt_for_cursor_position` for the same cursor
+ // position under two different editable-region token limits.
+ #[gpui::test]
+ fn test_excerpt_for_cursor_position(cx: &mut App) {
+ // Fixture: three small Rust functions; the cursor used below sits in `bar`.
+ let text = indoc! {r#"
+ fn foo() {
+ let x = 42;
+ println!("Hello, world!");
+ }
+
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ return sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.gen_range(1..101));
+ }
+ numbers
+ }
+ "#};
+ // The buffer is given the Rust language; presumably this is what lets the snapshot answer
+ // the `syntax_ancestor` queries made by `excerpt_for_cursor_position`.
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let snapshot = buffer.read(cx).snapshot();
+
+ // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
+ // when a larger scope doesn't fit the editable region.
+ // Arguments: editable-region limit of 50 tokens, context limit of 32 tokens. The expected
+ // prompt shows the cursor marker inside `return sum;`.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ let x = 42;
+ println!("Hello, world!");
+ <|editable_region_start|>
+ }
+
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ <|editable_region_end|>
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ ```"#}
+ );
+
+ // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
+ // Same cursor and context limit as above, but the editable-region limit drops to 40 tokens.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ <|editable_region_start|>
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ <|editable_region_end|>
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.gen_range(1..101));
+ ```"#}
+ );
+ }
+
+ /// Test helper: a `Language` named "Rust", matched by the `rs` path suffix and backed by the
+ /// `tree_sitter_rust` grammar.
+ fn rust_lang() -> Language {
+     let matcher = LanguageMatcher {
+         path_suffixes: vec!["rs".to_string()],
+         ..Default::default()
+     };
+     let config = LanguageConfig {
+         name: "Rust".into(),
+         matcher,
+         ..Default::default()
+     };
+     Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
+ }
+}
@@ -1,5 +1,6 @@
mod completion_diff_element;
mod init;
+mod input_excerpt;
mod license_detection;
mod onboarding_banner;
mod onboarding_modal;
@@ -25,7 +26,7 @@ use gpui::{
use http_client::{HttpClient, Method};
use language::{
language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview,
- OffsetRangeExt, Point, ToOffset, ToPoint,
+ OffsetRangeExt, ToOffset, ToPoint,
};
use language_models::LlmApiToken;
use postage::watch;
@@ -61,26 +62,26 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;
-/// Output token limit, used to inform the size of the input. A copy of this constant is also in
+/// Input token limit, used to inform the size of the input. A copy of this constant is also in
/// `crates/collab/src/llm.rs`.
-const MAX_OUTPUT_TOKENS: usize = 2048;
+const MAX_INPUT_TOKENS: usize = 2048;
+
+/// Token budget for the read-only context lines placed around the editable region; passed as
+/// `context_token_limit` to `input_excerpt::excerpt_for_cursor_position`.
+const MAX_CONTEXT_TOKENS: usize = 64;
+/// Token budget for the editable region of the excerpt; passed as
+/// `editable_region_token_limit` to `input_excerpt::excerpt_for_cursor_position`.
+const MAX_OUTPUT_TOKENS: usize = 256;
/// Total bytes limit for editable region of buffer excerpt.
///
/// The number of output tokens is relevant to the size of the input excerpt because the model is
/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
/// remaining for the model to specify insertions.
-const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
-
-/// Total line limit for editable region of buffer excerpt.
-const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;
+const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_INPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
/// Note that this is not the limit for the overall prompt, just for the inputs to the template
/// instantiated in `crates/collab/src/llm.rs`.
const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;
/// Maximum number of events to include in the prompt.
-const MAX_EVENT_COUNT: usize = 16;
+const MAX_EVENT_COUNT: usize = 8;
/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of
/// equally splitting up the the remaining bytes after the largest possible buffer excerpt.
@@ -373,8 +374,8 @@ impl Zeta {
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
let snapshot = self.report_changes_for_buffer(&buffer, cx);
- let cursor_point = cursor.to_point(&snapshot);
- let cursor_offset = cursor_point.to_offset(&snapshot);
+ let cursor_position = cursor.to_point(&snapshot);
+ let cursor_offset = cursor_position.to_offset(&snapshot);
let events = self.events.clone();
let path: Arc<Path> = snapshot
.file()
@@ -389,45 +390,47 @@ impl Zeta {
cx.spawn(|_, cx| async move {
let request_sent_at = Instant::now();
- let (input_events, input_excerpt, excerpt_range, input_outline) = cx
- .background_executor()
- .spawn({
- let snapshot = snapshot.clone();
- let path = path.clone();
- async move {
- let path = path.to_string_lossy();
- let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
- cursor_point,
- BUFFER_EXCERPT_BYTE_LIMIT,
- BUFFER_EXCERPT_LINE_LIMIT,
- &path,
- &snapshot,
- )?;
- let input_excerpt = prompt_for_excerpt(
- cursor_offset,
- &excerpt_range,
- excerpt_len_guess,
- &path,
- &snapshot,
- );
-
- let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
- let input_events = prompt_for_events(events.iter(), bytes_remaining);
-
- // Note that input_outline is not currently used in prompt generation and so
- // is not counted towards TOTAL_BYTE_LIMIT.
- let input_outline = prompt_for_outline(&snapshot);
-
- anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
- }
- })
- .await?;
+ let (input_events, input_excerpt, editable_range, input_outline, speculated_output) =
+ cx.background_executor()
+ .spawn({
+ let snapshot = snapshot.clone();
+ let path = path.clone();
+ async move {
+ let path = path.to_string_lossy();
+ let input_excerpt = input_excerpt::excerpt_for_cursor_position(
+ cursor_position,
+ &path,
+ &snapshot,
+ MAX_OUTPUT_TOKENS,
+ MAX_CONTEXT_TOKENS,
+ );
+
+ let bytes_remaining =
+ TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.prompt.len());
+ let input_events = prompt_for_events(events.iter(), bytes_remaining);
+
+ // Note that input_outline is not currently used in prompt generation and so
+ // is not counted towards TOTAL_BYTE_LIMIT.
+ let input_outline = prompt_for_outline(&snapshot);
+
+ let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+ anyhow::Ok((
+ input_events,
+ input_excerpt.prompt,
+ editable_range,
+ input_outline,
+ input_excerpt.speculated_output,
+ ))
+ }
+ })
+ .await?;
log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
let body = PredictEditsParams {
input_events: input_events.clone(),
input_excerpt: input_excerpt.clone(),
+ speculated_output,
outline: Some(input_outline.clone()),
data_collection_permission,
};
@@ -441,7 +444,7 @@ impl Zeta {
output_excerpt,
buffer,
&snapshot,
- excerpt_range,
+ editable_range,
cursor_offset,
path,
input_outline,
@@ -457,6 +460,8 @@ impl Zeta {
// Generates several example completions of various states to fill the Zeta completion modal
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
+ use language::Point;
+
let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
@@ -675,7 +680,7 @@ and then another
output_excerpt: String,
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
- excerpt_range: Range<usize>,
+ editable_range: Range<usize>,
cursor_offset: usize,
path: Arc<Path>,
input_outline: String,
@@ -692,9 +697,9 @@ and then another
.background_executor()
.spawn({
let output_excerpt = output_excerpt.clone();
- let excerpt_range = excerpt_range.clone();
+ let editable_range = editable_range.clone();
let snapshot = snapshot.clone();
- async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) }
+ async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
})
.await?
.into();
@@ -717,7 +722,7 @@ and then another
Ok(Some(InlineCompletion {
id: InlineCompletionId::new(),
path,
- excerpt_range,
+ excerpt_range: editable_range,
cursor_offset,
edits,
edit_preview,
@@ -734,7 +739,7 @@ and then another
fn parse_edits(
output_excerpt: Arc<str>,
- excerpt_range: Range<usize>,
+ editable_range: Range<usize>,
snapshot: &BufferSnapshot,
) -> Result<Vec<(Range<Anchor>, String)>> {
let content = output_excerpt.replace(CURSOR_MARKER, "");
@@ -778,13 +783,13 @@ and then another
let new_text = &content[..codefence_end];
let old_text = snapshot
- .text_for_range(excerpt_range.clone())
+ .text_for_range(editable_range.clone())
.collect::<String>();
Ok(Self::compute_edits(
old_text,
new_text,
- excerpt_range.start,
+ editable_range.start,
&snapshot,
))
}
@@ -1011,161 +1016,6 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
input_outline
}
-fn prompt_for_excerpt(
- offset: usize,
- excerpt_range: &Range<usize>,
- mut len_guess: usize,
- path: &str,
- snapshot: &BufferSnapshot,
-) -> String {
- let point_range = excerpt_range.to_point(snapshot);
-
- // Include one line of extra context before and after editable range, if those lines are non-empty.
- let extra_context_before_range =
- if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
- let range =
- (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
- len_guess += range.end - range.start;
- Some(range)
- } else {
- None
- };
- let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
- && !snapshot.is_line_blank(point_range.end.row + 1)
- {
- let range = (point_range.end
- ..Point::new(
- point_range.end.row + 1,
- snapshot.line_len(point_range.end.row + 1),
- ))
- .to_offset(snapshot);
- len_guess += range.end - range.start;
- Some(range)
- } else {
- None
- };
-
- let mut prompt_excerpt = String::with_capacity(len_guess);
- writeln!(prompt_excerpt, "```{}", path).unwrap();
-
- if excerpt_range.start == 0 {
- writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
- }
-
- if let Some(extra_context_before_range) = extra_context_before_range {
- for chunk in snapshot.text_for_range(extra_context_before_range) {
- prompt_excerpt.push_str(chunk);
- }
- }
- writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
- for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
- prompt_excerpt.push_str(chunk);
- }
- prompt_excerpt.push_str(CURSOR_MARKER);
- for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
- prompt_excerpt.push_str(chunk);
- }
- write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
-
- if let Some(extra_context_after_range) = extra_context_after_range {
- for chunk in snapshot.text_for_range(extra_context_after_range) {
- prompt_excerpt.push_str(chunk);
- }
- }
-
- write!(prompt_excerpt, "\n```").unwrap();
- debug_assert!(
- prompt_excerpt.len() <= len_guess,
- "Excerpt length {} exceeds estimated length {}",
- prompt_excerpt.len(),
- len_guess
- );
- prompt_excerpt
-}
-
-fn excerpt_range_for_position(
- cursor_point: Point,
- byte_limit: usize,
- line_limit: u32,
- path: &str,
- snapshot: &BufferSnapshot,
-) -> Result<(Range<usize>, usize)> {
- let cursor_row = cursor_point.row;
- let last_buffer_row = snapshot.max_point().row;
-
- // This is an overestimate because it includes parts of prompt_for_excerpt which are
- // conditionally skipped.
- let mut len_guess = 0;
- len_guess += "```".len() + path.len() + 1;
- len_guess += START_OF_FILE_MARKER.len() + 1;
- len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
- len_guess += CURSOR_MARKER.len();
- len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
- len_guess += "```".len() + 1;
-
- len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();
-
- if len_guess > byte_limit {
- return Err(anyhow!("Current line too long to send to model."));
- }
-
- let mut excerpt_start_row = cursor_row;
- let mut excerpt_end_row = cursor_row;
- let mut no_more_before = cursor_row == 0;
- let mut no_more_after = cursor_row >= last_buffer_row;
- let mut row_delta = 1;
- loop {
- if !no_more_before {
- let row = cursor_point.row - row_delta;
- let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
- let mut new_len_guess = len_guess + line_len;
- if row == 0 {
- new_len_guess += START_OF_FILE_MARKER.len() + 1;
- }
- if new_len_guess <= byte_limit {
- len_guess = new_len_guess;
- excerpt_start_row = row;
- if row == 0 {
- no_more_before = true;
- }
- } else {
- no_more_before = true;
- }
- }
- if excerpt_end_row - excerpt_start_row >= line_limit {
- break;
- }
- if !no_more_after {
- let row = cursor_point.row + row_delta;
- let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
- let new_len_guess = len_guess + line_len;
- if new_len_guess <= byte_limit {
- len_guess = new_len_guess;
- excerpt_end_row = row;
- if row >= last_buffer_row {
- no_more_after = true;
- }
- } else {
- no_more_after = true;
- }
- }
- if excerpt_end_row - excerpt_start_row >= line_limit {
- break;
- }
- if no_more_before && no_more_after {
- break;
- }
- row_delta += 1;
- }
-
- let excerpt_start = Point::new(excerpt_start_row, 0);
- let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
- Ok((
- excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
- len_guess,
- ))
-}
-
fn prompt_for_events<'a>(
events: impl Iterator<Item = &'a Event>,
mut bytes_remaining: usize,
@@ -1671,6 +1521,7 @@ mod tests {
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use indoc::indoc;
+ use language::Point;
use language_models::RefreshLlmTokenListener;
use rpc::proto;
use settings::SettingsStore;