diff --git a/assets/prompts/content_prompt.hbs b/assets/prompts/content_prompt.hbs index cd429e13048eb483354c1c03e9af0bc3559c3b2c..107b6be0425d549a53279f1821926d37d526802b 100644 --- a/assets/prompts/content_prompt.hbs +++ b/assets/prompts/content_prompt.hbs @@ -1,14 +1,12 @@ +Here's a text file that I'm going to ask you to make an edit to. + {{#if language_name}} -Here's a file of {{language_name}} that I'm going to ask you to make an edit to. -{{else}} -Here's a file of text that I'm going to ask you to make an edit to. +The file is in {{language_name}}. {{/if}} -{{#if is_insert}} -The point you'll need to insert at is marked with . -{{else}} -The section you'll need to rewrite is marked with tags. -{{/if}} +You need to rewrite a portion of it. + +The section you'll need to edit is marked with tags. {{{document_content}}} @@ -18,44 +16,37 @@ The section you'll need to rewrite is marked with The context around the relevant section has been truncated (possibly in the middle of a line) for brevity. {{/if}} -{{#if is_insert}} -You can't replace {{content_type}}, your answer will be inserted in place of the `` tags. Don't include the insert_here tags in your output. - -Generate {{content_type}} based on the following prompt: +Rewrite the section of {{content_type}} in tags based on the following prompt: {{{user_prompt}}} -Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines. - -Immediately start with the following format with no remarks: +Here's the section to edit based on that prompt again for reference: -``` -{{INSERTED_CODE}} -``` -{{else}} -Edit the section of {{content_type}} in tags based on the following prompt: + +{{{rewrite_section}}} + - -{{{user_prompt}}} - +You'll rewrite this entire section, but you will only make changes within certain subsections. -{{#if rewrite_section}} -And here's the section to rewrite based on that prompt again for reference: +{{#if has_insertion}} +Insert text anywhere you see it marked with with tags. Do not include tags in your output. +{{/if}} +{{#if has_replacement}} +Edit edit text that you see surrounded with tags. Do not include tags in your output. +{{/if}} -{{{rewrite_section}}} +{{{rewrite_section_with_selections}}} -{{/if}} -Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. +Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. Do not output the tags or anything outside of them. -Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions. +Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make. Always write out the whole section with no unnecessary elisions. Immediately start with the following format with no remarks: ``` -{{REWRITTEN_CODE}} +\{{REWRITTEN_CODE}} ``` -{{/if}} diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 77c9a381c64cd751a959cd05f3cb5e0a571620eb..5cf3c39b98af46a3995ad59fd3915bd3a309263d 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -30,6 +30,7 @@ use language_model::{ }; pub(crate) use model_selector::*; pub use prompts::PromptBuilder; +use prompts::PromptOverrideContext; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; @@ -168,7 +169,12 @@ impl Assistant { } } -pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) -> Arc { +pub fn init( + fs: Arc, + client: Arc, + dev_mode: bool, + cx: &mut AppContext, +) -> Arc { cx.set_global(Assistant::default()); AssistantSettings::register(cx); @@ -203,10 +209,14 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) -> Arc>::new(); - let mut newest_selection = None; - for mut selection in editor.read(cx).selections.all::(cx) { - if selection.end > selection.start { - selection.start.column = 0; - // If the selection ends at the start of the line, we don't want to include it. - if selection.end.column == 0 { - selection.end.row -= 1; - } - selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row)); - } + struct CodegenRange { + transform_range: Range, + selection_ranges: Vec>, + focus_assist: bool, + } - if let Some(prev_selection) = selections.last_mut() { - if selection.start <= prev_selection.end { - prev_selection.end = selection.end; + let newest_selection = editor.read(cx).selections.newest::(cx); + let mut codegen_ranges: Vec = Vec::new(); + for selection in editor.read(cx).selections.all::(cx) { + let selection_is_newest = selection.id == newest_selection.id; + let mut transform_range = selection.start..selection.end; + + // Expand the transform range to start/end of lines. + // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line. + transform_range.start.column = 0; + if transform_range.end.column == 0 && transform_range.end > transform_range.start { + transform_range.end.row -= 1; + } + transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row)); + let selection_range = selection.start..selection.end.min(transform_range.end); + + // If we intersect the previous transform range, + if let Some(CodegenRange { + transform_range: prev_transform_range, + selection_ranges, + focus_assist, + }) = codegen_ranges.last_mut() + { + if transform_range.start <= prev_transform_range.end { + prev_transform_range.end = transform_range.end; + selection_ranges.push(selection_range); + *focus_assist |= selection_is_newest; continue; } } - let latest_selection = newest_selection.get_or_insert_with(|| selection.clone()); - if selection.id > latest_selection.id { - *latest_selection = selection.clone(); - } - selections.push(selection); - } - let newest_selection = newest_selection.unwrap(); - - let mut codegen_ranges = Vec::new(); - for (excerpt_id, buffer, buffer_range) in - snapshot.excerpts_in_ranges(selections.iter().map(|selection| { - snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end) - })) - { - let start = Anchor { - buffer_id: Some(buffer.remote_id()), - excerpt_id, - text_anchor: buffer.anchor_before(buffer_range.start), - }; - let end = Anchor { - buffer_id: Some(buffer.remote_id()), - excerpt_id, - text_anchor: buffer.anchor_after(buffer_range.end), - }; - codegen_ranges.push(start..end); + codegen_ranges.push(CodegenRange { + transform_range, + selection_ranges: vec![selection_range], + focus_assist: selection_is_newest, + }) } let assist_group_id = self.next_assist_group_id.post_inc(); let prompt_buffer = cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)); let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx)); - let mut assists = Vec::new(); let mut assist_to_focus = None; - for range in codegen_ranges { - let assist_id = self.next_assist_id.post_inc(); + + for CodegenRange { + transform_range, + selection_ranges, + focus_assist, + } in codegen_ranges + { + let transform_range = snapshot.anchor_before(transform_range.start) + ..snapshot.anchor_after(transform_range.end); + let selection_ranges = selection_ranges + .iter() + .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)) + .collect::>(); + let codegen = cx.new_model(|cx| { Codegen::new( editor.read(cx).buffer().clone(), - range.clone(), + transform_range.clone(), + selection_ranges, None, self.telemetry.clone(), self.prompt_builder.clone(), @@ -206,6 +215,7 @@ impl InlineAssistant { ) }); + let assist_id = self.next_assist_id.post_inc(); let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); let prompt_editor = cx.new_view(|cx| { PromptEditor::new( @@ -222,23 +232,16 @@ impl InlineAssistant { ) }); - if assist_to_focus.is_none() { - let focus_assist = if newest_selection.reversed { - range.start.to_point(&snapshot) == newest_selection.start - } else { - range.end.to_point(&snapshot) == newest_selection.end - }; - if focus_assist { - assist_to_focus = Some(assist_id); - } + if focus_assist { + assist_to_focus = Some(assist_id); } let [prompt_block_id, end_block_id] = - self.insert_assist_blocks(editor, &range, &prompt_editor, cx); + self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx); assists.push(( assist_id, - range, + transform_range, prompt_editor, prompt_block_id, end_block_id, @@ -305,6 +308,7 @@ impl InlineAssistant { Codegen::new( editor.read(cx).buffer().clone(), range.clone(), + vec![range.clone()], initial_transaction_id, self.telemetry.clone(), self.prompt_builder.clone(), @@ -888,12 +892,7 @@ impl InlineAssistant { assist .codegen .update(cx, |codegen, cx| { - codegen.start( - assist.range.clone(), - user_prompt, - assistant_panel_context, - cx, - ) + codegen.start(user_prompt, assistant_panel_context, cx) }) .log_err(); @@ -2084,12 +2083,9 @@ impl InlineAssist { return future::ready(Err(anyhow!("no user prompt"))).boxed(); }; let assistant_panel_context = self.assistant_panel_context(cx); - self.codegen.read(cx).count_tokens( - self.range.clone(), - user_prompt, - assistant_panel_context, - cx, - ) + self.codegen + .read(cx) + .count_tokens(user_prompt, assistant_panel_context, cx) } } @@ -2110,6 +2106,8 @@ pub struct Codegen { buffer: Model, old_buffer: Model, snapshot: MultiBufferSnapshot, + transform_range: Range, + selected_ranges: Vec>, edit_position: Option, last_equal_ranges: Vec>, initial_transaction_id: Option, @@ -2119,7 +2117,7 @@ pub struct Codegen { diff: Diff, telemetry: Option>, _subscription: gpui::Subscription, - builder: Arc, + prompt_builder: Arc, } enum CodegenStatus { @@ -2146,7 +2144,8 @@ impl EventEmitter for Codegen {} impl Codegen { pub fn new( buffer: Model, - range: Range, + transform_range: Range, + selected_ranges: Vec>, initial_transaction_id: Option, telemetry: Option>, builder: Arc, @@ -2156,7 +2155,7 @@ impl Codegen { let (old_buffer, _, _) = buffer .read(cx) - .range_to_buffer_ranges(range.clone(), cx) + .range_to_buffer_ranges(transform_range.clone(), cx) .pop() .unwrap(); let old_buffer = cx.new_model(|cx| { @@ -2187,7 +2186,9 @@ impl Codegen { telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), initial_transaction_id, - builder, + prompt_builder: builder, + transform_range, + selected_ranges, } } @@ -2212,13 +2213,12 @@ impl Codegen { pub fn count_tokens( &self, - edit_range: Range, user_prompt: String, assistant_panel_context: Option, cx: &AppContext, ) -> BoxFuture<'static, Result> { if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { - let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx); + let request = self.build_request(user_prompt, assistant_panel_context, cx); match request { Ok(request) => model.count_tokens(request, cx), Err(error) => futures::future::ready(Err(error)).boxed(), @@ -2230,7 +2230,6 @@ impl Codegen { pub fn start( &mut self, - edit_range: Range, user_prompt: String, assistant_panel_context: Option, cx: &mut ModelContext, @@ -2245,24 +2244,20 @@ impl Codegen { }); } - self.edit_position = Some(edit_range.start.bias_right(&self.snapshot)); + self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot)); let telemetry_id = model.telemetry_id(); - let chunks: LocalBoxFuture>>> = if user_prompt - .trim() - .to_lowercase() - == "delete" - { - async { Ok(stream::empty().boxed()) }.boxed_local() - } else { - let request = - self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?; + let chunks: LocalBoxFuture>>> = + if user_prompt.trim().to_lowercase() == "delete" { + async { Ok(stream::empty().boxed()) }.boxed_local() + } else { + let request = self.build_request(user_prompt, assistant_panel_context, cx)?; - let chunks = - cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); - async move { Ok(chunks.await?.boxed()) }.boxed_local() - }; - self.handle_stream(telemetry_id, edit_range, chunks, cx); + let chunks = + cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); + async move { Ok(chunks.await?.boxed()) }.boxed_local() + }; + self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx); Ok(()) } @@ -2270,11 +2265,10 @@ impl Codegen { &self, user_prompt: String, assistant_panel_context: Option, - edit_range: Range, cx: &AppContext, ) -> Result { let buffer = self.buffer.read(cx).snapshot(cx); - let language = buffer.language_at(edit_range.start); + let language = buffer.language_at(self.transform_range.start); let language_name = if let Some(language) = language.as_ref() { if Arc::ptr_eq(language, &language::PLAIN_TEXT) { None @@ -2299,8 +2293,8 @@ impl Codegen { }; let language_name = language_name.as_deref(); - let start = buffer.point_to_buffer_offset(edit_range.start); - let end = buffer.point_to_buffer_offset(edit_range.end); + let start = buffer.point_to_buffer_offset(self.transform_range.start); + let end = buffer.point_to_buffer_offset(self.transform_range.end); let (buffer, range) = if let Some((start, end)) = start.zip(end) { let (start_buffer, start_buffer_offset) = start; let (end_buffer, end_buffer_offset) = end; @@ -2312,9 +2306,20 @@ impl Codegen { } else { return Err(anyhow::anyhow!("invalid transformation range")); }; + + let selected_ranges = self + .selected_ranges + .iter() + .map(|range| { + let start = range.start.text_anchor.to_offset(&buffer); + let end = range.end.text_anchor.to_offset(&buffer); + start..end + }) + .collect::>(); + let prompt = self - .builder - .generate_content_prompt(user_prompt, language_name, buffer, range) + .prompt_builder + .generate_content_prompt(user_prompt, language_name, buffer, range, selected_ranges) .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; let mut messages = Vec::new(); @@ -2386,84 +2391,19 @@ impl Codegen { let mut diff = StreamingDiff::new(selected_text.to_string()); let mut line_diff = LineDiff::default(); - let mut new_text = String::new(); - let mut base_indent = None; - let mut line_indent = None; - let mut first_line = true; - while let Some(chunk) = chunks.next().await { if response_latency.is_none() { response_latency = Some(request_start.elapsed()); } let chunk = chunk?; - - let mut lines = chunk.split('\n').peekable(); - while let Some(line) = lines.next() { - new_text.push_str(line); - if line_indent.is_none() { - if let Some(non_whitespace_ch_ix) = - new_text.find(|ch: char| !ch.is_whitespace()) - { - line_indent = Some(non_whitespace_ch_ix); - base_indent = base_indent.or(line_indent); - - let line_indent = line_indent.unwrap(); - let base_indent = base_indent.unwrap(); - let indent_delta = - line_indent as i32 - base_indent as i32; - let mut corrected_indent_len = cmp::max( - 0, - suggested_line_indent.len as i32 + indent_delta, - ) - as usize; - if first_line { - corrected_indent_len = corrected_indent_len - .saturating_sub( - selection_start.column as usize, - ); - } - - let indent_char = suggested_line_indent.char(); - let mut indent_buffer = [0; 4]; - let indent_str = - indent_char.encode_utf8(&mut indent_buffer); - new_text.replace_range( - ..line_indent, - &indent_str.repeat(corrected_indent_len), - ); - } - } - - if line_indent.is_some() { - let char_ops = diff.push_new(&new_text); - line_diff - .push_char_operations(&char_ops, &selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; - new_text.clear(); - } - - if lines.peek().is_some() { - let char_ops = diff.push_new("\n"); - line_diff - .push_char_operations(&char_ops, &selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; - if line_indent.is_none() { - // Don't write out the leading indentation in empty lines on the next line - // This is the case where the above if statement didn't clear the buffer - new_text.clear(); - } - line_indent = None; - first_line = false; - } - } + let char_ops = diff.push_new(&chunk); + line_diff.push_char_operations(&char_ops, &selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; } - let mut char_ops = diff.push_new(&new_text); - char_ops.extend(diff.finish()); + let char_ops = diff.finish(); line_diff.push_char_operations(&char_ops, &selected_text); line_diff.finish(&selected_text); diff_tx @@ -2824,311 +2764,13 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use futures::stream::{self}; - use gpui::{Context, TestAppContext}; - use indoc::indoc; - use language::{ - language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, - Point, - }; - use language_model::LanguageModelRegistry; - use rand::prelude::*; use serde::Serialize; - use settings::SettingsStore; - use std::{future, sync::Arc}; #[derive(Serialize)] pub struct DummyCompletionRequest { pub name: String, } - #[gpui::test(iterations = 10)] - async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_model::LanguageModelRegistry::test); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - let x = 0; - for _ in 0..10 { - x += 1; - } - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - range.clone(), - None, - None, - prompt_builder, - cx, - ) - }); - - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - codegen.update(cx, |codegen, cx| { - codegen.handle_stream( - String::new(), - range, - future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), - cx, - ) - }); - - let mut new_text = concat!( - " let mut x = 0;\n", - " while x < 10 {\n", - " x += 1;\n", - " }", - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_past_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - fn main() { - le - } - "}; - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - range.clone(), - None, - None, - prompt_builder, - cx, - ) - }); - - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - codegen.update(cx, |codegen, cx| { - codegen.handle_stream( - String::new(), - range.clone(), - future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), - cx, - ) - }); - - cx.background_executor.run_until_parked(); - - let mut new_text = concat!( - "t mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_when_generating_before_indentation( - cx: &mut TestAppContext, - mut rng: StdRng, - ) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = concat!( - "fn main() {\n", - " \n", - "}\n" // - ); - let buffer = - cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - range.clone(), - None, - None, - prompt_builder, - cx, - ) - }); - - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - codegen.update(cx, |codegen, cx| { - codegen.handle_stream( - String::new(), - range.clone(), - future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), - cx, - ) - }); - - cx.background_executor.run_until_parked(); - - let mut new_text = concat!( - "let mut x = 0;\n", - "while x < 10 {\n", - " x += 1;\n", - "}", // - ); - while !new_text.is_empty() { - let max_len = cmp::min(new_text.len(), 10); - let len = rng.gen_range(1..=max_len); - let (chunk, suffix) = new_text.split_at(len); - chunks_tx.unbounded_send(chunk.to_string()).unwrap(); - new_text = suffix; - cx.background_executor.run_until_parked(); - } - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - fn main() { - let mut x = 0; - while x < 10 { - x += 1; - } - } - "} - ); - } - - #[gpui::test(iterations = 10)] - async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { - cx.update(LanguageModelRegistry::test); - cx.set_global(cx.update(SettingsStore::test)); - cx.update(language_settings::init); - - let text = indoc! {" - func main() { - \tx := 0 - \tfor i := 0; i < 10; i++ { - \t\tx++ - \t} - } - "}; - let buffer = cx.new_model(|cx| Buffer::local(text, cx)); - let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); - let range = buffer.read_with(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) - }); - let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - range.clone(), - None, - None, - prompt_builder, - cx, - ) - }); - - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - codegen.update(cx, |codegen, cx| { - codegen.handle_stream( - String::new(), - range.clone(), - future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), - cx, - ) - }); - - let new_text = concat!( - "func main() {\n", - "\tx := 0\n", - "\tfor x < 10 {\n", - "\t\tx++\n", - "\t}", // - ); - chunks_tx.unbounded_send(new_text.to_string()).unwrap(); - drop(chunks_tx); - cx.background_executor.run_until_parked(); - - assert_eq!( - buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), - indoc! {" - func main() { - \tx := 0 - \tfor x < 10 { - \t\tx++ - \t} - } - "} - ); - } - #[gpui::test] async fn test_strip_invalid_spans_from_codeblock() { assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; @@ -3168,27 +2810,4 @@ mod tests { ) } } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::language()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - } } diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index c66c929fd182ce050f03092b97473b6fe9049d08..4bbab19a9ac767a85ed959fa2ab177f0abfc0d88 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -12,11 +12,13 @@ use util::ResultExt; pub struct ContentPromptContext { pub content_type: String, pub language_name: Option, - pub is_insert: bool, pub is_truncated: bool, pub document_content: String, pub user_prompt: String, - pub rewrite_section: Option, + pub rewrite_section: String, + pub rewrite_section_with_selections: String, + pub has_insertion: bool, + pub has_replacement: bool, } #[derive(Serialize)] @@ -33,41 +35,54 @@ pub struct PromptBuilder { handlebars: Arc>>, } +pub struct PromptOverrideContext<'a> { + pub dev_mode: bool, + pub fs: Arc, + pub cx: &'a mut gpui::AppContext, +} + impl PromptBuilder { - pub fn new( - fs_and_cx: Option<(Arc, &gpui::AppContext)>, - ) -> Result> { + pub fn new(override_cx: Option) -> Result> { let mut handlebars = Handlebars::new(); Self::register_templates(&mut handlebars)?; let handlebars = Arc::new(Mutex::new(handlebars)); - if let Some((fs, cx)) = fs_and_cx { - Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone()); + if let Some(override_cx) = override_cx { + Self::watch_fs_for_template_overrides(override_cx, handlebars.clone()); } Ok(Self { handlebars }) } fn watch_fs_for_template_overrides( - fs: Arc, - cx: &gpui::AppContext, + PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext, handlebars: Arc>>, ) { - let templates_dir = paths::prompt_overrides_dir(); - cx.background_executor() .spawn(async move { + let templates_dir = if dev_mode { + std::env::current_dir() + .ok() + .and_then(|pwd| { + let pwd_assets_prompts = pwd.join("assets").join("prompts"); + pwd_assets_prompts.exists().then_some(pwd_assets_prompts) + }) + .unwrap_or_else(|| paths::prompt_overrides_dir().clone()) + } else { + paths::prompt_overrides_dir().clone() + }; + // Create the prompt templates directory if it doesn't exist - if !fs.is_dir(templates_dir).await { - if let Err(e) = fs.create_dir(templates_dir).await { + if !fs.is_dir(&templates_dir).await { + if let Err(e) = fs.create_dir(&templates_dir).await { log::error!("Failed to create prompt templates directory: {}", e); return; } } // Initial scan of the prompts directory - if let Ok(mut entries) = fs.read_dir(templates_dir).await { + if let Ok(mut entries) = fs.read_dir(&templates_dir).await { while let Some(Ok(file_path)) = entries.next().await { if file_path.to_string_lossy().ends_with(".hbs") { if let Ok(content) = fs.load(&file_path).await { @@ -95,7 +110,7 @@ impl PromptBuilder { } // Watch for changes - let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await; + let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await; while let Some(changed_paths) = changes.next().await { for changed_path in changed_paths { if changed_path.extension().map_or(false, |ext| ext == "hbs") { @@ -147,7 +162,8 @@ impl PromptBuilder { user_prompt: String, language_name: Option<&str>, buffer: BufferSnapshot, - range: Range, + transform_range: Range, + selected_ranges: Vec>, ) -> Result { let content_type = match language_name { None | Some("Markdown" | "Plain Text") => "text", @@ -155,21 +171,20 @@ impl PromptBuilder { }; const MAX_CTX: usize = 50000; - let is_insert = range.is_empty(); let mut is_truncated = false; - let before_range = 0..range.start; + let before_range = 0..transform_range.start; let truncated_before = if before_range.len() > MAX_CTX { is_truncated = true; - range.start - MAX_CTX..range.start + transform_range.start - MAX_CTX..transform_range.start } else { before_range }; - let after_range = range.end..buffer.len(); + let after_range = transform_range.end..buffer.len(); let truncated_after = if after_range.len() > MAX_CTX { is_truncated = true; - range.end..range.end + MAX_CTX + transform_range.end..transform_range.end + MAX_CTX } else { after_range }; @@ -178,37 +193,61 @@ impl PromptBuilder { for chunk in buffer.text_for_range(truncated_before) { document_content.push_str(chunk); } - if is_insert { - document_content.push_str(""); - } else { - document_content.push_str("\n"); - for chunk in buffer.text_for_range(range.clone()) { - document_content.push_str(chunk); - } - document_content.push_str("\n"); + document_content.push_str("\n"); + for chunk in buffer.text_for_range(transform_range.clone()) { + document_content.push_str(chunk); } + document_content.push_str("\n"); + for chunk in buffer.text_for_range(truncated_after) { document_content.push_str(chunk); } - let rewrite_section = if !is_insert { - let mut section = String::new(); - for chunk in buffer.text_for_range(range.clone()) { - section.push_str(chunk); + let mut rewrite_section = String::new(); + for chunk in buffer.text_for_range(transform_range.clone()) { + rewrite_section.push_str(chunk); + } + + let rewrite_section_with_selections = { + let mut section_with_selections = String::new(); + let mut last_end = 0; + for selected_range in &selected_ranges { + if selected_range.start > last_end { + section_with_selections.push_str( + &rewrite_section[last_end..selected_range.start - transform_range.start], + ); + } + if selected_range.start == selected_range.end { + section_with_selections.push_str(""); + } else { + section_with_selections.push_str(""); + section_with_selections.push_str( + &rewrite_section[selected_range.start - transform_range.start + ..selected_range.end - transform_range.start], + ); + section_with_selections.push_str(""); + } + last_end = selected_range.end - transform_range.start; } - Some(section) - } else { - None + if last_end < rewrite_section.len() { + section_with_selections.push_str(&rewrite_section[last_end..]); + } + section_with_selections }; + let has_insertion = selected_ranges.iter().any(|range| range.start == range.end); + let has_replacement = selected_ranges.iter().any(|range| range.start != range.end); + let context = ContentPromptContext { content_type: content_type.to_string(), language_name: language_name.map(|s| s.to_string()), - is_insert, is_truncated, document_content, user_prompt, rewrite_section, + rewrite_section_with_selections, + has_insertion, + has_replacement, }; self.handlebars.lock().render("content_prompt", &context) diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 2eb5b3fc05842e8f307f3280dd4df910e4155433..173387b8406433723fcd9812dd52c5445a79b9f7 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -187,7 +187,12 @@ fn init_common(app_state: Arc, cx: &mut AppContext) -> Arc