diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 93ac5bdee8493eb0725456cd990e37451e85e3fd..92cd6dccca22799b212b0a8a882f998654d53344 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -39,6 +39,7 @@ pub struct PredictEditsParams { pub outline: Option, pub input_events: String, pub input_excerpt: String, + pub speculated_output: String, /// Whether the user provided consent for sampling this interaction. #[serde(default)] pub data_collection_permission: bool, diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/input_excerpt.rs new file mode 100644 index 0000000000000000000000000000000000000000..103f03750a54f207a3e14137d6baa473616f8407 --- /dev/null +++ b/crates/zeta/src/input_excerpt.rs @@ -0,0 +1,238 @@ +use crate::{ + BYTES_PER_TOKEN_GUESS, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, + START_OF_FILE_MARKER, +}; +use language::{BufferSnapshot, Point}; +use std::{fmt::Write, ops::Range}; + +pub struct InputExcerpt { + pub editable_range: Range, + pub prompt: String, + pub speculated_output: String, +} + +pub fn excerpt_for_cursor_position( + position: Point, + path: &str, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> InputExcerpt { + let mut scope_range = position..position; + let mut remaining_edit_tokens = editable_region_token_limit; + + while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { + let parent_tokens = tokens_for_bytes(parent.byte_range().len()); + if parent_tokens <= editable_region_token_limit { + scope_range = Point::new( + parent.start_position().row as u32, + parent.start_position().column as u32, + ) + ..Point::new( + parent.end_position().row as u32, + parent.end_position().column as u32, + ); + remaining_edit_tokens = editable_region_token_limit - parent_tokens; + } else { + break; + } + } + + let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); + let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); + + let mut prompt = String::new(); + let mut speculated_output = String::new(); + + writeln!(&mut prompt, "```{path}").unwrap(); + if context_range.start == Point::zero() { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); + } + + for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { + prompt.push_str(chunk.text); + } + + push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); + push_editable_range( + position, + snapshot, + editable_range.clone(), + &mut speculated_output, + ); + + for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n```").unwrap(); + + InputExcerpt { + editable_range, + prompt, + speculated_output, + } +} + +fn push_editable_range( + cursor_position: Point, + snapshot: &BufferSnapshot, + editable_range: Range, + prompt: &mut String, +) { + writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { + prompt.push_str(chunk.text); + } + prompt.push_str(CURSOR_MARKER); + for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); +} + +fn expand_range( + snapshot: &BufferSnapshot, + range: Range, + mut remaining_tokens: usize, +) -> Range { + let mut expanded_range = range.clone(); + expanded_range.start.column = 0; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + loop { + let mut expanded = false; + + if remaining_tokens > 0 && expanded_range.start.row > 0 { + expanded_range.start.row -= 1; + let line_tokens = + tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { + expanded_range.end.row += 1; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + let line_tokens = tokens_for_bytes(expanded_range.end.column as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if !expanded { + break; + } + } + expanded_range +} + +fn tokens_for_bytes(bytes: usize) -> usize { + bytes / BYTES_PER_TOKEN_GUESS +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{App, AppContext}; + use indoc::indoc; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher}; + use std::sync::Arc; + + #[gpui::test] + fn test_excerpt_for_cursor_position(cx: &mut App) { + let text = indoc! {r#" + fn foo() { + let x = 42; + println!("Hello, world!"); + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + return sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.gen_range(1..101)); + } + numbers + } + "#}; + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let snapshot = buffer.read(cx).snapshot(); + + // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion + // when a larger scope doesn't fit the editable region. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + let x = 42; + println!("Hello, world!"); + <|editable_region_start|> + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + <|editable_region_end|> + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + ```"#} + ); + + // The `bar` function won't fit within the editable region, so we resort to line-based expansion. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + fn bar() { + let x = 42; + let mut sum = 0; + <|editable_region_start|> + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + <|editable_region_end|> + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.gen_range(1..101)); + ```"#} + ); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + } +} diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 1611614ca3aa70aad67de813447ef255f968995e..a3c178d591c5b2e5244c5c6caf427abd47471af7 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,5 +1,6 @@ mod completion_diff_element; mod init; +mod input_excerpt; mod license_detection; mod onboarding_banner; mod onboarding_modal; @@ -25,7 +26,7 @@ use gpui::{ use http_client::{HttpClient, Method}; use language::{ language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, - OffsetRangeExt, Point, ToOffset, ToPoint, + OffsetRangeExt, ToOffset, ToPoint, }; use language_models::LlmApiToken; use postage::watch; @@ -61,26 +62,26 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch /// intentionally low to err on the side of underestimating limits. const BYTES_PER_TOKEN_GUESS: usize = 3; -/// Output token limit, used to inform the size of the input. A copy of this constant is also in +/// Input token limit, used to inform the size of the input. A copy of this constant is also in /// `crates/collab/src/llm.rs`. -const MAX_OUTPUT_TOKENS: usize = 2048; +const MAX_INPUT_TOKENS: usize = 2048; + +const MAX_CONTEXT_TOKENS: usize = 64; +const MAX_OUTPUT_TOKENS: usize = 256; /// Total bytes limit for editable region of buffer excerpt. /// /// The number of output tokens is relevant to the size of the input excerpt because the model is /// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens /// remaining for the model to specify insertions. -const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS; - -/// Total line limit for editable region of buffer excerpt. -const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64; +const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_INPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS; /// Note that this is not the limit for the overall prompt, just for the inputs to the template /// instantiated in `crates/collab/src/llm.rs`. const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2; /// Maximum number of events to include in the prompt. -const MAX_EVENT_COUNT: usize = 16; +const MAX_EVENT_COUNT: usize = 8; /// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of /// equally splitting up the the remaining bytes after the largest possible buffer excerpt. @@ -373,8 +374,8 @@ impl Zeta { R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); + let cursor_position = cursor.to_point(&snapshot); + let cursor_offset = cursor_position.to_offset(&snapshot); let events = self.events.clone(); let path: Arc = snapshot .file() @@ -389,45 +390,47 @@ impl Zeta { cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); - let (input_events, input_excerpt, excerpt_range, input_outline) = cx - .background_executor() - .spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position( - cursor_point, - BUFFER_EXCERPT_BYTE_LIMIT, - BUFFER_EXCERPT_LINE_LIMIT, - &path, - &snapshot, - )?; - let input_excerpt = prompt_for_excerpt( - cursor_offset, - &excerpt_range, - excerpt_len_guess, - &path, - &snapshot, - ); - - let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len()); - let input_events = prompt_for_events(events.iter(), bytes_remaining); - - // Note that input_outline is not currently used in prompt generation and so - // is not counted towards TOTAL_BYTE_LIMIT. - let input_outline = prompt_for_outline(&snapshot); - - anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline)) - } - }) - .await?; + let (input_events, input_excerpt, editable_range, input_outline, speculated_output) = + cx.background_executor() + .spawn({ + let snapshot = snapshot.clone(); + let path = path.clone(); + async move { + let path = path.to_string_lossy(); + let input_excerpt = input_excerpt::excerpt_for_cursor_position( + cursor_position, + &path, + &snapshot, + MAX_OUTPUT_TOKENS, + MAX_CONTEXT_TOKENS, + ); + + let bytes_remaining = + TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.prompt.len()); + let input_events = prompt_for_events(events.iter(), bytes_remaining); + + // Note that input_outline is not currently used in prompt generation and so + // is not counted towards TOTAL_BYTE_LIMIT. + let input_outline = prompt_for_outline(&snapshot); + + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + anyhow::Ok(( + input_events, + input_excerpt.prompt, + editable_range, + input_outline, + input_excerpt.speculated_output, + )) + } + }) + .await?; log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt); let body = PredictEditsParams { input_events: input_events.clone(), input_excerpt: input_excerpt.clone(), + speculated_output, outline: Some(input_outline.clone()), data_collection_permission, }; @@ -441,7 +444,7 @@ impl Zeta { output_excerpt, buffer, &snapshot, - excerpt_range, + editable_range, cursor_offset, path, input_outline, @@ -457,6 +460,8 @@ impl Zeta { // Generates several example completions of various states to fill the Zeta completion modal #[cfg(any(test, feature = "test-support"))] pub fn fill_with_fake_completions(&mut self, cx: &mut Context) -> Task<()> { + use language::Point; + let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line And maybe a short line @@ -675,7 +680,7 @@ and then another output_excerpt: String, buffer: Entity, snapshot: &BufferSnapshot, - excerpt_range: Range, + editable_range: Range, cursor_offset: usize, path: Arc, input_outline: String, @@ -692,9 +697,9 @@ and then another .background_executor() .spawn({ let output_excerpt = output_excerpt.clone(); - let excerpt_range = excerpt_range.clone(); + let editable_range = editable_range.clone(); let snapshot = snapshot.clone(); - async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) } + async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) } }) .await? .into(); @@ -717,7 +722,7 @@ and then another Ok(Some(InlineCompletion { id: InlineCompletionId::new(), path, - excerpt_range, + excerpt_range: editable_range, cursor_offset, edits, edit_preview, @@ -734,7 +739,7 @@ and then another fn parse_edits( output_excerpt: Arc, - excerpt_range: Range, + editable_range: Range, snapshot: &BufferSnapshot, ) -> Result, String)>> { let content = output_excerpt.replace(CURSOR_MARKER, ""); @@ -778,13 +783,13 @@ and then another let new_text = &content[..codefence_end]; let old_text = snapshot - .text_for_range(excerpt_range.clone()) + .text_for_range(editable_range.clone()) .collect::(); Ok(Self::compute_edits( old_text, new_text, - excerpt_range.start, + editable_range.start, &snapshot, )) } @@ -1011,161 +1016,6 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { input_outline } -fn prompt_for_excerpt( - offset: usize, - excerpt_range: &Range, - mut len_guess: usize, - path: &str, - snapshot: &BufferSnapshot, -) -> String { - let point_range = excerpt_range.to_point(snapshot); - - // Include one line of extra context before and after editable range, if those lines are non-empty. - let extra_context_before_range = - if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) { - let range = - (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot); - len_guess += range.end - range.start; - Some(range) - } else { - None - }; - let extra_context_after_range = if point_range.end.row < snapshot.max_point().row - && !snapshot.is_line_blank(point_range.end.row + 1) - { - let range = (point_range.end - ..Point::new( - point_range.end.row + 1, - snapshot.line_len(point_range.end.row + 1), - )) - .to_offset(snapshot); - len_guess += range.end - range.start; - Some(range) - } else { - None - }; - - let mut prompt_excerpt = String::with_capacity(len_guess); - writeln!(prompt_excerpt, "```{}", path).unwrap(); - - if excerpt_range.start == 0 { - writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap(); - } - - if let Some(extra_context_before_range) = extra_context_before_range { - for chunk in snapshot.text_for_range(extra_context_before_range) { - prompt_excerpt.push_str(chunk); - } - } - writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap(); - for chunk in snapshot.text_for_range(excerpt_range.start..offset) { - prompt_excerpt.push_str(chunk); - } - prompt_excerpt.push_str(CURSOR_MARKER); - for chunk in snapshot.text_for_range(offset..excerpt_range.end) { - prompt_excerpt.push_str(chunk); - } - write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); - - if let Some(extra_context_after_range) = extra_context_after_range { - for chunk in snapshot.text_for_range(extra_context_after_range) { - prompt_excerpt.push_str(chunk); - } - } - - write!(prompt_excerpt, "\n```").unwrap(); - debug_assert!( - prompt_excerpt.len() <= len_guess, - "Excerpt length {} exceeds estimated length {}", - prompt_excerpt.len(), - len_guess - ); - prompt_excerpt -} - -fn excerpt_range_for_position( - cursor_point: Point, - byte_limit: usize, - line_limit: u32, - path: &str, - snapshot: &BufferSnapshot, -) -> Result<(Range, usize)> { - let cursor_row = cursor_point.row; - let last_buffer_row = snapshot.max_point().row; - - // This is an overestimate because it includes parts of prompt_for_excerpt which are - // conditionally skipped. - let mut len_guess = 0; - len_guess += "```".len() + path.len() + 1; - len_guess += START_OF_FILE_MARKER.len() + 1; - len_guess += EDITABLE_REGION_START_MARKER.len() + 1; - len_guess += CURSOR_MARKER.len(); - len_guess += EDITABLE_REGION_END_MARKER.len() + 1; - len_guess += "```".len() + 1; - - len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap(); - - if len_guess > byte_limit { - return Err(anyhow!("Current line too long to send to model.")); - } - - let mut excerpt_start_row = cursor_row; - let mut excerpt_end_row = cursor_row; - let mut no_more_before = cursor_row == 0; - let mut no_more_after = cursor_row >= last_buffer_row; - let mut row_delta = 1; - loop { - if !no_more_before { - let row = cursor_point.row - row_delta; - let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); - let mut new_len_guess = len_guess + line_len; - if row == 0 { - new_len_guess += START_OF_FILE_MARKER.len() + 1; - } - if new_len_guess <= byte_limit { - len_guess = new_len_guess; - excerpt_start_row = row; - if row == 0 { - no_more_before = true; - } - } else { - no_more_before = true; - } - } - if excerpt_end_row - excerpt_start_row >= line_limit { - break; - } - if !no_more_after { - let row = cursor_point.row + row_delta; - let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); - let new_len_guess = len_guess + line_len; - if new_len_guess <= byte_limit { - len_guess = new_len_guess; - excerpt_end_row = row; - if row >= last_buffer_row { - no_more_after = true; - } - } else { - no_more_after = true; - } - } - if excerpt_end_row - excerpt_start_row >= line_limit { - break; - } - if no_more_before && no_more_after { - break; - } - row_delta += 1; - } - - let excerpt_start = Point::new(excerpt_start_row, 0); - let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row)); - Ok(( - excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot), - len_guess, - )) -} - fn prompt_for_events<'a>( events: impl Iterator, mut bytes_remaining: usize, @@ -1671,6 +1521,7 @@ mod tests { use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; + use language::Point; use language_models::RefreshLlmTokenListener; use rpc::proto; use settings::SettingsStore;