From 02078140c010f5518bc990dede574f996db6f43b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 11 Sep 2023 11:25:37 +0200 Subject: [PATCH 1/4] Extract code generation logic into its own module --- crates/ai/src/ai.rs | 1 + crates/ai/src/assistant.rs | 553 ++++++------------------------ crates/ai/src/codegen.rs | 468 +++++++++++++++++++++++++ crates/editor/src/editor.rs | 10 +- crates/editor/src/multi_buffer.rs | 43 ++- 5 files changed, 607 insertions(+), 468 deletions(-) create mode 100644 crates/ai/src/codegen.rs diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 2c2d7e774e120980d212cdcbd887289ebee0e768..2e8eca80e3c2dbf1fe99e3849c6eae26917f4440 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,6 @@ pub mod assistant; mod assistant_settings; +mod codegen; mod streaming_diff; use anyhow::{anyhow, Result}; diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 9b384252fc0dfea3fe7897c1152fbca18fbcd9e0..1d56a6308cfeab744a092968f5f9bc2d5470efb6 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,9 +1,8 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - stream_completion, - streaming_diff::{Hunk, StreamingDiff}, - MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role, - SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, + codegen::{self, Codegen, OpenAICompletionProvider}, + stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, + Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -13,10 +12,10 @@ use editor::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, }, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint, + Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, }; use fs::Fs; -use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; +use futures::StreamExt; use gpui::{ actions, elements::{ @@ -30,17 +29,14 @@ use gpui::{ ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; -use language::{ - language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _, - TransactionId, -}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use search::BufferSearchBar; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, cmp, env, fmt::Write, - future, iter, + iter, ops::Range, path::{Path, PathBuf}, rc::Rc, @@ -266,10 +262,22 @@ impl AssistantPanel { } fn new_inline_assist(&mut self, editor: &ViewHandle, cx: &mut ViewContext) { + let api_key = if let Some(api_key) = self.api_key.borrow().clone() { + api_key + } else { + return; + }; + let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let selection = editor.read(cx).selections.newest_anchor().clone(); let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot); + let provider = Arc::new(OpenAICompletionProvider::new( + api_key, + cx.background().clone(), + )); + let codegen = + cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx)); let assist_kind = if editor.read(cx).selections.newest::(cx).is_empty() { InlineAssistKind::Generate } else { @@ -283,6 +291,7 @@ impl AssistantPanel { measurements.clone(), self.include_conversation_in_next_inline_assist, self.inline_prompt_history.clone(), + codegen.clone(), cx, ); cx.focus_self(); @@ -323,44 +332,53 @@ impl AssistantPanel { PendingInlineAssist { kind: assist_kind, editor: editor.downgrade(), - range, - highlighted_ranges: Default::default(), inline_assistant: Some((block_id, inline_assistant.clone())), - code_generation: Task::ready(None), - transaction_id: None, + codegen: codegen.clone(), _subscriptions: vec![ cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event), cx.subscribe(editor, { let inline_assistant = inline_assistant.downgrade(); - move |this, editor, event, cx| { + move |_, editor, event, cx| { if let Some(inline_assistant) = inline_assistant.upgrade(cx) { - match event { - editor::Event::SelectionsChanged { local } => { - if *local && inline_assistant.read(cx).has_focus { - cx.focus(&editor); - } + if let editor::Event::SelectionsChanged { local } = event { + if *local && inline_assistant.read(cx).has_focus { + cx.focus(&editor); } - editor::Event::TransactionUndone { - transaction_id: tx_id, - } => { - if let Some(pending_assist) = - this.pending_inline_assists.get(&inline_assist_id) - { - if pending_assist.transaction_id == Some(*tx_id) { - // Notice we are supplying `undo: false` here. This - // is because there's no need to undo the transaction - // because the user just did so. - this.close_inline_assist( - inline_assist_id, - false, - cx, - ); - } - } + } + } + } + }), + cx.subscribe(&codegen, move |this, codegen, event, cx| match event { + codegen::Event::Undone => { + this.finish_inline_assist(inline_assist_id, false, cx) + } + codegen::Event::Finished => { + let pending_assist = if let Some(pending_assist) = + this.pending_inline_assists.get(&inline_assist_id) + { + pending_assist + } else { + return; + }; + + let error = codegen + .read(cx) + .error() + .map(|error| format!("Inline assistant error: {}", error)); + if let Some(error) = error { + if pending_assist.inline_assistant.is_none() { + if let Some(workspace) = this.workspace.upgrade(cx) { + workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new(inline_assist_id, error), + cx, + ); + }) } - _ => {} } } + + this.finish_inline_assist(inline_assist_id, false, cx); } }), ], @@ -388,7 +406,7 @@ impl AssistantPanel { self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); } InlineAssistantEvent::Canceled => { - self.close_inline_assist(assist_id, true, cx); + self.finish_inline_assist(assist_id, true, cx); } InlineAssistantEvent::Dismissed => { self.hide_inline_assist(assist_id, cx); @@ -417,7 +435,7 @@ impl AssistantPanel { .get(&editor.downgrade()) .and_then(|assist_ids| assist_ids.last().copied()) { - panel.close_inline_assist(assist_id, true, cx); + panel.finish_inline_assist(assist_id, true, cx); true } else { false @@ -432,7 +450,7 @@ impl AssistantPanel { cx.propagate_action(); } - fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { + fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext) { self.hide_inline_assist(assist_id, cx); if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) { @@ -450,13 +468,9 @@ impl AssistantPanel { self.update_highlights_for_editor(&editor, cx); if undo { - if let Some(transaction_id) = pending_assist.transaction_id { - editor.update(cx, |editor, cx| { - editor.buffer().update(cx, |buffer, cx| { - buffer.undo_transaction(transaction_id, cx) - }); - }); - } + pending_assist + .codegen + .update(cx, |codegen, cx| codegen.undo(cx)); } } } @@ -481,12 +495,6 @@ impl AssistantPanel { include_conversation: bool, cx: &mut ViewContext, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - let conversation = if include_conversation { self.active_editor() .map(|editor| editor.read(cx).conversation.clone()) @@ -514,56 +522,9 @@ impl AssistantPanel { self.inline_prompt_history.pop_front(); } - let range = pending_assist.range.clone(); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let selected_text = snapshot - .text_for_range(range.start..range.end) - .collect::(); - - let selection_start = range.start.to_point(&snapshot); - let selection_end = range.end.to_point(&snapshot); - - let mut base_indent: Option = None; - let mut start_row = selection_start.row; - if snapshot.is_line_blank(start_row) { - if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { - start_row = prev_non_blank_row; - } - } - for row in start_row..=selection_end.row { - if snapshot.is_line_blank(row) { - continue; - } - - let line_indent = snapshot.indent_size_for_line(row); - if let Some(base_indent) = base_indent.as_mut() { - if line_indent.len < base_indent.len { - *base_indent = line_indent; - } - } else { - base_indent = Some(line_indent); - } - } - - let mut normalized_selected_text = selected_text.clone(); - if let Some(base_indent) = base_indent { - for row in selection_start.row..=selection_end.row { - let selection_row = row - selection_start.row; - let line_start = - normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); - let indent_len = if row == selection_start.row { - base_indent.len.saturating_sub(selection_start.column) - } else { - let line_len = normalized_selected_text.line_len(selection_row); - cmp::min(line_len, base_indent.len) - }; - let indent_end = cmp::min( - line_start + indent_len as usize, - normalized_selected_text.len(), - ); - normalized_selected_text.replace(line_start..indent_end, ""); - } - } + let range = pending_assist.codegen.read(cx).range(); + let selected_text = snapshot.text_for_range(range.clone()).collect::(); let language = snapshot.language_at(range.start); let language_name = if let Some(language) = language.as_ref() { @@ -608,7 +569,7 @@ impl AssistantPanel { } else { writeln!(prompt, "```").unwrap(); } - writeln!(prompt, "{normalized_selected_text}").unwrap(); + writeln!(prompt, "{selected_text}").unwrap(); writeln!(prompt, "```").unwrap(); writeln!(prompt).unwrap(); writeln!( @@ -689,209 +650,9 @@ impl AssistantPanel { messages, stream: true, }; - let response = stream_completion(api_key, cx.background().clone(), request); - let editor = editor.downgrade(); - - pending_assist.code_generation = cx.spawn(|this, mut cx| { - async move { - let mut edit_start = range.start.to_offset(&snapshot); - - let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); - let diff = cx.background().spawn(async move { - let chunks = strip_markdown_codeblock(response.await?.filter_map( - |message| async move { - match message { - Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }, - )); - futures::pin_mut!(chunks); - let mut diff = StreamingDiff::new(selected_text.to_string()); - - let mut indent_len; - let indent_text; - if let Some(base_indent) = base_indent { - indent_len = base_indent.len; - indent_text = match base_indent.kind { - language::IndentKind::Space => " ", - language::IndentKind::Tab => "\t", - }; - } else { - indent_len = 0; - indent_text = ""; - }; - - let mut first_line_len = 0; - let mut first_line_non_whitespace_char_ix = None; - let mut first_line = true; - let mut new_text = String::new(); - - while let Some(chunk) = chunks.next().await { - let chunk = chunk?; - - let mut lines = chunk.split('\n'); - if let Some(mut line) = lines.next() { - if first_line { - if first_line_non_whitespace_char_ix.is_none() { - if let Some(mut char_ix) = - line.find(|ch: char| !ch.is_whitespace()) - { - line = &line[char_ix..]; - char_ix += first_line_len; - first_line_non_whitespace_char_ix = Some(char_ix); - let first_line_indent = char_ix - .saturating_sub(selection_start.column as usize) - as usize; - new_text.push_str(&indent_text.repeat(first_line_indent)); - indent_len = indent_len.saturating_sub(char_ix as u32); - } - } - first_line_len += line.len(); - } - - if first_line_non_whitespace_char_ix.is_some() { - new_text.push_str(line); - } - } - - for line in lines { - first_line = false; - new_text.push('\n'); - if !line.is_empty() { - new_text.push_str(&indent_text.repeat(indent_len as usize)); - } - new_text.push_str(line); - } - - let hunks = diff.push_new(&new_text); - hunks_tx.send(hunks).await?; - new_text.clear(); - } - hunks_tx.send(diff.finish()).await?; - - anyhow::Ok(()) - }); - - while let Some(hunks) = hunks_rx.next().await { - let editor = if let Some(editor) = editor.upgrade(&cx) { - editor - } else { - break; - }; - - let this = if let Some(this) = this.upgrade(&cx) { - this - } else { - break; - }; - - this.update(&mut cx, |this, cx| { - let pending_assist = if let Some(pending_assist) = - this.pending_inline_assists.get_mut(&inline_assist_id) - { - pending_assist - } else { - return; - }; - - pending_assist.highlighted_ranges.clear(); - editor.update(cx, |editor, cx| { - let transaction = editor.buffer().update(cx, |buffer, cx| { - // Avoid grouping assistant edits with user edits. - buffer.finalize_last_transaction(cx); - - buffer.start_transaction(cx); - buffer.edit( - hunks.into_iter().filter_map(|hunk| match hunk { - Hunk::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - Some((edit_start..edit_start, text)) - } - Hunk::Remove { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - Some((edit_range, String::new())) - } - Hunk::Keep { len } => { - let edit_end = edit_start + len; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start += len; - pending_assist.highlighted_ranges.push(edit_range); - None - } - }), - None, - cx, - ); - - buffer.end_transaction(cx) - }); - - if let Some(transaction) = transaction { - if let Some(first_transaction) = pending_assist.transaction_id { - // Group all assistant edits into the first transaction. - editor.buffer().update(cx, |buffer, cx| { - buffer.merge_transactions( - transaction, - first_transaction, - cx, - ) - }); - } else { - pending_assist.transaction_id = Some(transaction); - editor.buffer().update(cx, |buffer, cx| { - buffer.finalize_last_transaction(cx) - }); - } - } - }); - - this.update_highlights_for_editor(&editor, cx); - }); - } - - if let Err(error) = diff.await { - this.update(&mut cx, |this, cx| { - let pending_assist = if let Some(pending_assist) = - this.pending_inline_assists.get_mut(&inline_assist_id) - { - pending_assist - } else { - return; - }; - - if let Some((_, inline_assistant)) = - pending_assist.inline_assistant.as_ref() - { - inline_assistant.update(cx, |inline_assistant, cx| { - inline_assistant.set_error(error, cx); - }); - } else if let Some(workspace) = this.workspace.upgrade(cx) { - workspace.update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - inline_assist_id, - format!("Inline assistant error: {}", error), - ), - cx, - ); - }) - } - })?; - } else { - let _ = this.update(&mut cx, |this, cx| { - this.close_inline_assist(inline_assist_id, false, cx) - }); - } - - anyhow::Ok(()) - } - .log_err() - }); + pending_assist + .codegen + .update(cx, |codegen, cx| codegen.start(request, cx)); } fn update_highlights_for_editor( @@ -909,8 +670,9 @@ impl AssistantPanel { for inline_assist_id in inline_assist_ids { if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) { - background_ranges.push(pending_assist.range.clone()); - foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned()); + let codegen = pending_assist.codegen.read(cx); + background_ranges.push(codegen.range()); + foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); } } @@ -2900,11 +2662,11 @@ struct InlineAssistant { has_focus: bool, include_conversation: bool, measurements: Rc>, - error: Option, prompt_history: VecDeque, prompt_history_ix: Option, pending_prompt: String, - _subscription: Subscription, + codegen: ModelHandle, + _subscriptions: Vec, } impl Entity for InlineAssistant { @@ -2933,7 +2695,7 @@ impl View for InlineAssistant { .element() .aligned(), ) - .with_children(if let Some(error) = self.error.as_ref() { + .with_children(if let Some(error) = self.codegen.read(cx).error() { Some( Svg::new("icons/circle_x_mark_12.svg") .with_color(theme.assistant.error_icon.color) @@ -3011,6 +2773,7 @@ impl InlineAssistant { measurements: Rc>, include_conversation: bool, prompt_history: VecDeque, + codegen: ModelHandle, cx: &mut ViewContext, ) -> Self { let prompt_editor = cx.add_view(|cx| { @@ -3025,7 +2788,10 @@ impl InlineAssistant { editor.set_placeholder_text(placeholder, cx); editor }); - let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events); + let subscriptions = vec![ + cx.observe(&codegen, Self::handle_codegen_changed), + cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), + ]; Self { id, prompt_editor, @@ -3033,11 +2799,11 @@ impl InlineAssistant { has_focus: false, include_conversation, measurements, - error: None, prompt_history, prompt_history_ix: None, pending_prompt: String::new(), - _subscription: subscription, + codegen, + _subscriptions: subscriptions, } } @@ -3053,6 +2819,31 @@ impl InlineAssistant { } } + fn handle_codegen_changed(&mut self, _: ModelHandle, cx: &mut ViewContext) { + let is_read_only = !self.codegen.read(cx).idle(); + self.prompt_editor.update(cx, |editor, cx| { + let was_read_only = editor.read_only(); + if was_read_only != is_read_only { + if is_read_only { + editor.set_read_only(true); + editor.set_field_editor_style( + Some(Arc::new(|theme| { + theme.assistant.inline.disabled_editor.clone() + })), + cx, + ); + } else { + editor.set_read_only(false); + editor.set_field_editor_style( + Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), + cx, + ); + } + } + }); + cx.notify(); + } + fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { cx.emit(InlineAssistantEvent::Canceled); } @@ -3076,7 +2867,6 @@ impl InlineAssistant { include_conversation: self.include_conversation, }); self.confirmed = true; - self.error = None; cx.notify(); } } @@ -3093,19 +2883,6 @@ impl InlineAssistant { cx.notify(); } - fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext) { - self.error = Some(error); - self.confirmed = false; - self.prompt_editor.update(cx, |editor, cx| { - editor.set_read_only(false); - editor.set_field_editor_style( - Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), - cx, - ); - }); - cx.notify(); - } - fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext) { if let Some(ix) = self.prompt_history_ix { if ix > 0 { @@ -3154,11 +2931,8 @@ struct BlockMeasurements { struct PendingInlineAssist { kind: InlineAssistKind, editor: WeakViewHandle, - range: Range, - highlighted_ranges: Vec>, inline_assistant: Option<(BlockId, ViewHandle)>, - code_generation: Task>, - transaction_id: Option, + codegen: ModelHandle, _subscriptions: Vec, } @@ -3184,65 +2958,10 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } -fn strip_markdown_codeblock( - stream: impl Stream>, -) -> impl Stream> { - let mut first_line = true; - let mut buffer = String::new(); - let mut starts_with_fenced_code_block = false; - stream.filter_map(move |chunk| { - let chunk = match chunk { - Ok(chunk) => chunk, - Err(err) => return future::ready(Some(Err(err))), - }; - buffer.push_str(&chunk); - - if first_line { - if buffer == "" || buffer == "`" || buffer == "``" { - return future::ready(None); - } else if buffer.starts_with("```") { - starts_with_fenced_code_block = true; - if let Some(newline_ix) = buffer.find('\n') { - buffer.replace_range(..newline_ix + 1, ""); - first_line = false; - } else { - return future::ready(None); - } - } - } - - let text = if starts_with_fenced_code_block { - buffer - .strip_suffix("\n```\n") - .or_else(|| buffer.strip_suffix("\n```")) - .or_else(|| buffer.strip_suffix("\n``")) - .or_else(|| buffer.strip_suffix("\n`")) - .or_else(|| buffer.strip_suffix('\n')) - .unwrap_or(&buffer) - } else { - &buffer - }; - - if text.contains('\n') { - first_line = false; - } - - let remainder = buffer.split_off(text.len()); - let result = if buffer.is_empty() { - None - } else { - Some(Ok(buffer.clone())) - }; - buffer = remainder; - future::ready(result) - }) -} - #[cfg(test)] mod tests { use super::*; use crate::MessageId; - use futures::stream; use gpui::AppContext; #[gpui::test] @@ -3611,62 +3330,6 @@ mod tests { ); } - #[gpui::test] - async fn test_strip_markdown_codeblock() { - assert_eq!( - strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "Lorem ipsum dolor" - ); - assert_eq!( - strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "```js\nLorem ipsum dolor\n```" - ); - assert_eq!( - strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, - "``\nLorem ipsum dolor\n```" - ); - - fn chunks(text: &str, size: usize) -> impl Stream> { - stream::iter( - text.chars() - .collect::>() - .chunks(size) - .map(|chunk| Ok(chunk.iter().collect::())) - .collect::>(), - ) - } - } - fn messages( conversation: &ModelHandle, cx: &AppContext, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs new file mode 100644 index 0000000000000000000000000000000000000000..b24c0f94356111e5d045a58394696bca902efcd8 --- /dev/null +++ b/crates/ai/src/codegen.rs @@ -0,0 +1,468 @@ +use crate::{ + stream_completion, + streaming_diff::{Hunk, StreamingDiff}, + OpenAIRequest, +}; +use anyhow::Result; +use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint}; +use futures::{ + channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, +}; +use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; +use language::{IndentSize, Point, Rope, TransactionId}; +use std::{cmp, future, ops::Range, sync::Arc}; + +pub trait CompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>>; +} + +pub struct OpenAICompletionProvider { + api_key: String, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(api_key: String, executor: Arc) -> Self { + Self { api_key, executor } + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } +} + +pub enum Event { + Finished, + Undone, +} + +pub struct Codegen { + provider: Arc, + buffer: ModelHandle, + range: Range, + last_equal_ranges: Vec>, + transaction_id: Option, + error: Option, + generation: Task<()>, + idle: bool, + _subscription: gpui::Subscription, +} + +impl Entity for Codegen { + type Event = Event; +} + +impl Codegen { + pub fn new( + buffer: ModelHandle, + range: Range, + provider: Arc, + cx: &mut ModelContext, + ) -> Self { + Self { + provider, + buffer: buffer.clone(), + range, + last_equal_ranges: Default::default(), + transaction_id: Default::default(), + error: Default::default(), + idle: true, + generation: Task::ready(()), + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), + } + } + + fn handle_buffer_event( + &mut self, + _buffer: ModelHandle, + event: &multi_buffer::Event, + cx: &mut ModelContext, + ) { + if let multi_buffer::Event::TransactionUndone { transaction_id } = event { + if self.transaction_id == Some(*transaction_id) { + self.transaction_id = None; + self.generation = Task::ready(()); + cx.emit(Event::Undone); + } + } + } + + pub fn range(&self) -> Range { + self.range.clone() + } + + pub fn last_equal_ranges(&self) -> &[Range] { + &self.last_equal_ranges + } + + pub fn idle(&self) -> bool { + self.idle + } + + pub fn error(&self) -> Option<&anyhow::Error> { + self.error.as_ref() + } + + pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { + let range = self.range.clone(); + let snapshot = self.buffer.read(cx).snapshot(cx); + let selected_text = snapshot + .text_for_range(range.start..range.end) + .collect::(); + + let selection_start = range.start.to_point(&snapshot); + let selection_end = range.end.to_point(&snapshot); + + let mut base_indent: Option = None; + let mut start_row = selection_start.row; + if snapshot.is_line_blank(start_row) { + if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { + start_row = prev_non_blank_row; + } + } + for row in start_row..=selection_end.row { + if snapshot.is_line_blank(row) { + continue; + } + + let line_indent = snapshot.indent_size_for_line(row); + if let Some(base_indent) = base_indent.as_mut() { + if line_indent.len < base_indent.len { + *base_indent = line_indent; + } + } else { + base_indent = Some(line_indent); + } + } + + let mut normalized_selected_text = selected_text.clone(); + if let Some(base_indent) = base_indent { + for row in selection_start.row..=selection_end.row { + let selection_row = row - selection_start.row; + let line_start = + normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); + let indent_len = if row == selection_start.row { + base_indent.len.saturating_sub(selection_start.column) + } else { + let line_len = normalized_selected_text.line_len(selection_row); + cmp::min(line_len, base_indent.len) + }; + let indent_end = cmp::min( + line_start + indent_len as usize, + normalized_selected_text.len(), + ); + normalized_selected_text.replace(line_start..indent_end, ""); + } + } + + let response = self.provider.complete(prompt); + self.generation = cx.spawn_weak(|this, mut cx| { + async move { + let generate = async { + let mut edit_start = range.start.to_offset(&snapshot); + + let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); + let diff = cx.background().spawn(async move { + let chunks = strip_markdown_codeblock(response.await?); + futures::pin_mut!(chunks); + let mut diff = StreamingDiff::new(selected_text.to_string()); + + let mut indent_len; + let indent_text; + if let Some(base_indent) = base_indent { + indent_len = base_indent.len; + indent_text = match base_indent.kind { + language::IndentKind::Space => " ", + language::IndentKind::Tab => "\t", + }; + } else { + indent_len = 0; + indent_text = ""; + }; + + let mut first_line_len = 0; + let mut first_line_non_whitespace_char_ix = None; + let mut first_line = true; + let mut new_text = String::new(); + + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + + let mut lines = chunk.split('\n'); + if let Some(mut line) = lines.next() { + if first_line { + if first_line_non_whitespace_char_ix.is_none() { + if let Some(mut char_ix) = + line.find(|ch: char| !ch.is_whitespace()) + { + line = &line[char_ix..]; + char_ix += first_line_len; + first_line_non_whitespace_char_ix = Some(char_ix); + let first_line_indent = char_ix + .saturating_sub(selection_start.column as usize) + as usize; + new_text + .push_str(&indent_text.repeat(first_line_indent)); + indent_len = indent_len.saturating_sub(char_ix as u32); + } + } + first_line_len += line.len(); + } + + if first_line_non_whitespace_char_ix.is_some() { + new_text.push_str(line); + } + } + + for line in lines { + first_line = false; + new_text.push('\n'); + if !line.is_empty() { + new_text.push_str(&indent_text.repeat(indent_len as usize)); + } + new_text.push_str(line); + } + + let hunks = diff.push_new(&new_text); + hunks_tx.send(hunks).await?; + new_text.clear(); + } + hunks_tx.send(diff.finish()).await?; + + anyhow::Ok(()) + }); + + while let Some(hunks) = hunks_rx.next().await { + let this = if let Some(this) = this.upgrade(&cx) { + this + } else { + break; + }; + + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + + let transaction = this.buffer.update(cx, |buffer, cx| { + // Avoid grouping assistant edits with user edits. + buffer.finalize_last_transaction(cx); + + buffer.start_transaction(cx); + buffer.edit( + hunks.into_iter().filter_map(|hunk| match hunk { + Hunk::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + Some((edit_start..edit_start, text)) + } + Hunk::Remove { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start = edit_end; + Some((edit_range, String::new())) + } + Hunk::Keep { len } => { + let edit_end = edit_start + len; + let edit_range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + edit_start += len; + this.last_equal_ranges.push(edit_range); + None + } + }), + None, + cx, + ); + + buffer.end_transaction(cx) + }); + + if let Some(transaction) = transaction { + if let Some(first_transaction) = this.transaction_id { + // Group all assistant edits into the first transaction. + this.buffer.update(cx, |buffer, cx| { + buffer.merge_transactions( + transaction, + first_transaction, + cx, + ) + }); + } else { + this.transaction_id = Some(transaction); + this.buffer.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(cx) + }); + } + } + + cx.notify(); + }); + } + + diff.await?; + anyhow::Ok(()) + }; + + let result = generate.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + this.last_equal_ranges.clear(); + this.idle = true; + if let Err(error) = result { + this.error = Some(error); + } + cx.emit(Event::Finished); + cx.notify(); + }); + } + } + }); + self.error.take(); + self.idle = false; + cx.notify(); + } + + pub fn undo(&mut self, cx: &mut ModelContext) { + if let Some(transaction_id) = self.transaction_id { + self.buffer + .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); + } + } +} + +fn strip_markdown_codeblock( + stream: impl Stream>, +) -> impl Stream> { + let mut first_line = true; + let mut buffer = String::new(); + let mut starts_with_fenced_code_block = false; + stream.filter_map(move |chunk| { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => return future::ready(Some(Err(err))), + }; + buffer.push_str(&chunk); + + if first_line { + if buffer == "" || buffer == "`" || buffer == "``" { + return future::ready(None); + } else if buffer.starts_with("```") { + starts_with_fenced_code_block = true; + if let Some(newline_ix) = buffer.find('\n') { + buffer.replace_range(..newline_ix + 1, ""); + first_line = false; + } else { + return future::ready(None); + } + } + } + + let text = if starts_with_fenced_code_block { + buffer + .strip_suffix("\n```\n") + .or_else(|| buffer.strip_suffix("\n```")) + .or_else(|| buffer.strip_suffix("\n``")) + .or_else(|| buffer.strip_suffix("\n`")) + .or_else(|| buffer.strip_suffix('\n')) + .unwrap_or(&buffer) + } else { + &buffer + }; + + if text.contains('\n') { + first_line = false; + } + + let remainder = buffer.split_off(text.len()); + let result = if buffer.is_empty() { + None + } else { + Some(Ok(buffer.clone())) + }; + buffer = remainder; + future::ready(result) + }) +} + +#[cfg(test)] +mod tests { + use futures::stream; + + use super::*; + + #[gpui::test] + async fn test_strip_markdown_codeblock() { + assert_eq!( + strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "```js\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "``\nLorem ipsum dolor\n```" + ); + + fn chunks(text: &str, size: usize) -> impl Stream> { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| Ok(chunk.iter().collect::())) + .collect::>(), + ) + } + } +} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index bdd29b04fa20e2ecd2492d6554f5541b2bc99540..12df29df1d4ea88d584f9e25d9a0df423b151606 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1734,6 +1734,10 @@ impl Editor { } } + pub fn read_only(&self) -> bool { + self.read_only + } + pub fn set_read_only(&mut self, read_only: bool) { self.read_only = read_only; } @@ -5103,9 +5107,6 @@ impl Editor { self.unmark_text(cx); self.refresh_copilot_suggestions(true, cx); cx.emit(Event::Edited); - cx.emit(Event::TransactionUndone { - transaction_id: tx_id, - }); } } @@ -8548,9 +8549,6 @@ pub enum Event { local: bool, autoscroll: bool, }, - TransactionUndone { - transaction_id: TransactionId, - }, Closed, } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index 74283fd778dda2a4d1f0125338eade33d89216cf..c5d17dfd2eb8ac67f81a19f63ccbb7f32f9c1c62 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -70,6 +70,9 @@ pub enum Event { Edited { sigleton_buffer_edited: bool, }, + TransactionUndone { + transaction_id: TransactionId, + }, Reloaded, DiffBaseChanged, LanguageChanged, @@ -771,30 +774,36 @@ impl MultiBuffer { } pub fn undo(&mut self, cx: &mut ModelContext) -> Option { + let mut transaction_id = None; if let Some(buffer) = self.as_singleton() { - return buffer.update(cx, |buffer, cx| buffer.undo(cx)); - } + transaction_id = buffer.update(cx, |buffer, cx| buffer.undo(cx)); + } else { + while let Some(transaction) = self.history.pop_undo() { + let mut undone = false; + for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions { + if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) { + undone |= buffer.update(cx, |buffer, cx| { + let undo_to = *buffer_transaction_id; + if let Some(entry) = buffer.peek_undo_stack() { + *buffer_transaction_id = entry.transaction_id(); + } + buffer.undo_to_transaction(undo_to, cx) + }); + } + } - while let Some(transaction) = self.history.pop_undo() { - let mut undone = false; - for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions { - if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) { - undone |= buffer.update(cx, |buffer, cx| { - let undo_to = *buffer_transaction_id; - if let Some(entry) = buffer.peek_undo_stack() { - *buffer_transaction_id = entry.transaction_id(); - } - buffer.undo_to_transaction(undo_to, cx) - }); + if undone { + transaction_id = Some(transaction.id); + break; } } + } - if undone { - return Some(transaction.id); - } + if let Some(transaction_id) = transaction_id { + cx.emit(Event::TransactionUndone { transaction_id }); } - None + transaction_id } pub fn redo(&mut self, cx: &mut ModelContext) -> Option { From 6d9333dc3b46f668cee18f03c47e6a64b8224e8a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 11 Sep 2023 14:35:15 +0200 Subject: [PATCH 2/4] Add a failing test for codegen autoindent --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/ai.rs | 2 +- crates/ai/src/codegen.rs | 113 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 115 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 121e9a28dd160e9eef6d4d7497c7876196a57c88..bf0ed9b163253cca33a4b857e4df502ba0de18f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "log", "menu", "ordered-float", + "parking_lot 0.11.2", "project", "rand 0.8.5", "regex", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 4438f88108988e715d476a65fc2566d8a5f8e090..d96e470d5cdce93923e156ca9afb4e32a6f921e1 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -27,6 +27,7 @@ futures.workspace = true indoc.workspace = true isahc.workspace = true ordered-float.workspace = true +parking_lot.workspace = true regex.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 2e8eca80e3c2dbf1fe99e3849c6eae26917f4440..7d9b93b0a7dbdcd41954cecc40c62fee14689a13 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -27,7 +27,7 @@ use util::paths::CONVERSATIONS_DIR; const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; // Data types for chat completion requests -#[derive(Debug, Serialize)] +#[derive(Debug, Default, Serialize)] pub struct OpenAIRequest { model: String, messages: Vec, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs index b24c0f94356111e5d045a58394696bca902efcd8..9657d9a4926c17eadb5ae7d78a020ee6342deacf 100644 --- a/crates/ai/src/codegen.rs +++ b/crates/ai/src/codegen.rs @@ -406,9 +406,68 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { + use super::*; use futures::stream; + use gpui::{executor::Deterministic, TestAppContext}; + use indoc::indoc; + use language::{tree_sitter_rust, Buffer, Language, LanguageConfig}; + use parking_lot::Mutex; + use rand::prelude::*; + + #[gpui::test(iterations = 10)] + async fn test_autoindent( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + let text = indoc! {" + fn main() { + let x = 0; + for _ in 0..10 { + x += 1; + } + } + "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx)); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = indoc! {" + let mut x = 0; + while x < 10 { + x += 1; + } + "}; + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); - use super::*; + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } #[gpui::test] async fn test_strip_markdown_codeblock() { @@ -465,4 +524,56 @@ mod tests { ) } } + + struct TestCompletionProvider { + last_completion_tx: Mutex>>, + } + + impl TestCompletionProvider { + fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } + } + + impl CompletionProvider for TestCompletionProvider { + fn complete( + &self, + _prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } } From b8c437529c4cfb0a30146b668896a37cf26f65a6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 11 Sep 2023 16:33:25 +0200 Subject: [PATCH 3/4] Never use the indentation that comes from OpenAI --- crates/ai/src/assistant.rs | 51 +++--- crates/ai/src/codegen.rs | 338 +++++++++++++++++++++++++------------ 2 files changed, 257 insertions(+), 132 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 1d56a6308cfeab744a092968f5f9bc2d5470efb6..6d4fce2f6d09de085f2e78e8f10aeaad8c0bd16c 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,6 +1,6 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - codegen::{self, Codegen, OpenAICompletionProvider}, + codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider}, stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; @@ -270,24 +270,28 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let selection = editor.read(cx).selections.newest_anchor().clone(); - let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot); let provider = Arc::new(OpenAICompletionProvider::new( api_key, cx.background().clone(), )); - let codegen = - cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx)); - let assist_kind = if editor.read(cx).selections.newest::(cx).is_empty() { - InlineAssistKind::Generate + let selection = editor.read(cx).selections.newest_anchor().clone(); + let codegen_kind = if editor.read(cx).selections.newest::(cx).is_empty() { + CodegenKind::Generate { + position: selection.start, + } } else { - InlineAssistKind::Transform + CodegenKind::Transform { + range: selection.start..selection.end, + } }; + let codegen = cx.add_model(|cx| { + Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) + }); + let measurements = Rc::new(Cell::new(BlockMeasurements::default())); let inline_assistant = cx.add_view(|cx| { let assistant = InlineAssistant::new( inline_assist_id, - assist_kind, measurements.clone(), self.include_conversation_in_next_inline_assist, self.inline_prompt_history.clone(), @@ -330,7 +334,6 @@ impl AssistantPanel { self.pending_inline_assists.insert( inline_assist_id, PendingInlineAssist { - kind: assist_kind, editor: editor.downgrade(), inline_assistant: Some((block_id, inline_assistant.clone())), codegen: codegen.clone(), @@ -348,6 +351,14 @@ impl AssistantPanel { } } }), + cx.observe(&codegen, { + let editor = editor.downgrade(); + move |this, _, cx| { + if let Some(editor) = editor.upgrade(cx) { + this.update_highlights_for_editor(&editor, cx); + } + } + }), cx.subscribe(&codegen, move |this, codegen, event, cx| match event { codegen::Event::Undone => { this.finish_inline_assist(inline_assist_id, false, cx) @@ -542,8 +553,8 @@ impl AssistantPanel { if let Some(language_name) = language_name { writeln!(prompt, "You're an expert {language_name} engineer.").unwrap(); } - match pending_assist.kind { - InlineAssistKind::Transform => { + match pending_assist.codegen.read(cx).kind() { + CodegenKind::Transform { .. } => { writeln!( prompt, "You're currently working inside an editor on this file:" @@ -583,7 +594,7 @@ impl AssistantPanel { ) .unwrap(); } - InlineAssistKind::Generate => { + CodegenKind::Generate { .. } => { writeln!( prompt, "You're currently working inside an editor on this file:" @@ -2649,12 +2660,6 @@ enum InlineAssistantEvent { }, } -#[derive(Copy, Clone)] -enum InlineAssistKind { - Transform, - Generate, -} - struct InlineAssistant { id: usize, prompt_editor: ViewHandle, @@ -2769,7 +2774,6 @@ impl View for InlineAssistant { impl InlineAssistant { fn new( id: usize, - kind: InlineAssistKind, measurements: Rc>, include_conversation: bool, prompt_history: VecDeque, @@ -2781,9 +2785,9 @@ impl InlineAssistant { Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), cx, ); - let placeholder = match kind { - InlineAssistKind::Transform => "Enter transformation prompt…", - InlineAssistKind::Generate => "Enter generation prompt…", + let placeholder = match codegen.read(cx).kind() { + CodegenKind::Transform { .. } => "Enter transformation prompt…", + CodegenKind::Generate { .. } => "Enter generation prompt…", }; editor.set_placeholder_text(placeholder, cx); editor @@ -2929,7 +2933,6 @@ struct BlockMeasurements { } struct PendingInlineAssist { - kind: InlineAssistKind, editor: WeakViewHandle, inline_assistant: Option<(BlockId, ViewHandle)>, codegen: ModelHandle, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs index 9657d9a4926c17eadb5ae7d78a020ee6342deacf..bc13e294fa8ca9a6e49c811a789ce663d4e1e37e 100644 --- a/crates/ai/src/codegen.rs +++ b/crates/ai/src/codegen.rs @@ -4,12 +4,14 @@ use crate::{ OpenAIRequest, }; use anyhow::Result; -use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint}; +use editor::{ + multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, +}; use futures::{ channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, }; use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; -use language::{IndentSize, Point, Rope, TransactionId}; +use language::{Rope, TransactionId}; use std::{cmp, future, ops::Range, sync::Arc}; pub trait CompletionProvider { @@ -57,10 +59,17 @@ pub enum Event { Undone, } +#[derive(Clone)] +pub enum CodegenKind { + Transform { range: Range }, + Generate { position: Anchor }, +} + pub struct Codegen { provider: Arc, buffer: ModelHandle, - range: Range, + snapshot: MultiBufferSnapshot, + kind: CodegenKind, last_equal_ranges: Vec>, transaction_id: Option, error: Option, @@ -76,14 +85,31 @@ impl Entity for Codegen { impl Codegen { pub fn new( buffer: ModelHandle, - range: Range, + mut kind: CodegenKind, provider: Arc, cx: &mut ModelContext, ) -> Self { + let snapshot = buffer.read(cx).snapshot(cx); + match &mut kind { + CodegenKind::Transform { range } => { + let mut point_range = range.to_point(&snapshot); + point_range.start.column = 0; + if point_range.end.column > 0 || point_range.start.row == point_range.end.row { + point_range.end.column = snapshot.line_len(point_range.end.row); + } + range.start = snapshot.anchor_before(point_range.start); + range.end = snapshot.anchor_after(point_range.end); + } + CodegenKind::Generate { position } => { + *position = position.bias_right(&snapshot); + } + } + Self { provider, buffer: buffer.clone(), - range, + snapshot, + kind, last_equal_ranges: Default::default(), transaction_id: Default::default(), error: Default::default(), @@ -109,7 +135,14 @@ impl Codegen { } pub fn range(&self) -> Range { - self.range.clone() + match &self.kind { + CodegenKind::Transform { range } => range.clone(), + CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position, + } + } + + pub fn kind(&self) -> &CodegenKind { + &self.kind } pub fn last_equal_ranges(&self) -> &[Range] { @@ -125,56 +158,18 @@ impl Codegen { } pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { - let range = self.range.clone(); - let snapshot = self.buffer.read(cx).snapshot(cx); + let range = self.range(); + let snapshot = self.snapshot.clone(); let selected_text = snapshot .text_for_range(range.start..range.end) .collect::(); let selection_start = range.start.to_point(&snapshot); - let selection_end = range.end.to_point(&snapshot); - - let mut base_indent: Option = None; - let mut start_row = selection_start.row; - if snapshot.is_line_blank(start_row) { - if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { - start_row = prev_non_blank_row; - } - } - for row in start_row..=selection_end.row { - if snapshot.is_line_blank(row) { - continue; - } - - let line_indent = snapshot.indent_size_for_line(row); - if let Some(base_indent) = base_indent.as_mut() { - if line_indent.len < base_indent.len { - *base_indent = line_indent; - } - } else { - base_indent = Some(line_indent); - } - } - - let mut normalized_selected_text = selected_text.clone(); - if let Some(base_indent) = base_indent { - for row in selection_start.row..=selection_end.row { - let selection_row = row - selection_start.row; - let line_start = - normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); - let indent_len = if row == selection_start.row { - base_indent.len.saturating_sub(selection_start.column) - } else { - let line_len = normalized_selected_text.line_len(selection_row); - cmp::min(line_len, base_indent.len) - }; - let indent_end = cmp::min( - line_start + indent_len as usize, - normalized_selected_text.len(), - ); - normalized_selected_text.replace(line_start..indent_end, ""); - } - } + let suggested_line_indent = snapshot + .suggested_indents(selection_start.row..selection_start.row + 1, cx) + .into_values() + .next() + .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row)); let response = self.provider.complete(prompt); self.generation = cx.spawn_weak(|this, mut cx| { @@ -188,66 +183,58 @@ impl Codegen { futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); - let mut indent_len; - let indent_text; - if let Some(base_indent) = base_indent { - indent_len = base_indent.len; - indent_text = match base_indent.kind { - language::IndentKind::Space => " ", - language::IndentKind::Tab => "\t", - }; - } else { - indent_len = 0; - indent_text = ""; - }; - - let mut first_line_len = 0; - let mut first_line_non_whitespace_char_ix = None; - let mut first_line = true; let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; while let Some(chunk) = chunks.next().await { let chunk = chunk?; - let mut lines = chunk.split('\n'); - if let Some(mut line) = lines.next() { - if first_line { - if first_line_non_whitespace_char_ix.is_none() { - if let Some(mut char_ix) = - line.find(|ch: char| !ch.is_whitespace()) - { - line = &line[char_ix..]; - char_ix += first_line_len; - first_line_non_whitespace_char_ix = Some(char_ix); - let first_line_indent = char_ix - .saturating_sub(selection_start.column as usize) - as usize; - new_text - .push_str(&indent_text.repeat(first_line_indent)); - indent_len = indent_len.saturating_sub(char_ix as u32); + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { + new_text.push_str(line); + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); + + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub(selection_start.column as usize); } - } - first_line_len += line.len(); - } - if first_line_non_whitespace_char_ix.is_some() { - new_text.push_str(line); + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); + } } - } - for line in lines { - first_line = false; - new_text.push('\n'); - if !line.is_empty() { - new_text.push_str(&indent_text.repeat(indent_len as usize)); + if lines.peek().is_some() { + hunks_tx.send(diff.push_new(&new_text)).await?; + hunks_tx.send(diff.push_new("\n")).await?; + new_text.clear(); + line_indent = None; + first_line = false; } - new_text.push_str(line); } - - let hunks = diff.push_new(&new_text); - hunks_tx.send(hunks).await?; - new_text.clear(); } + hunks_tx.send(diff.push_new(&new_text)).await?; hunks_tx.send(diff.finish()).await?; anyhow::Ok(()) @@ -285,7 +272,7 @@ impl Codegen { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); - edit_start += len; + edit_start = edit_end; this.last_equal_ranges.push(edit_range); None } @@ -410,16 +397,20 @@ mod tests { use futures::stream; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; - use language::{tree_sitter_rust, Buffer, Language, LanguageConfig}; + use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; + use settings::SettingsStore; #[gpui::test(iterations = 10)] - async fn test_autoindent( + async fn test_transform_autoindent( cx: &mut TestAppContext, mut rng: StdRng, deterministic: Arc, ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + let text = indoc! {" fn main() { let x = 0; @@ -436,15 +427,146 @@ mod tests { snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) }); let provider = Arc::new(TestCompletionProvider::new()); - let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx)); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Transform { range }, + provider.clone(), + cx, + ) + }); codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); - let mut new_text = indoc! {" - let mut x = 0; - while x < 10 { - x += 1; - } + let mut new_text = concat!( + " let mut x = 0;\n", + " while x < 10 {\n", + " x += 1;\n", + " }", + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_past_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + le + } "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 6)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Generate { position }, + provider.clone(), + cx, + ) + }); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = concat!( + "t mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_before_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + + let text = concat!( + "fn main() {\n", + " \n", + "}\n" // + ); + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 2)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Generate { position }, + provider.clone(), + cx, + ) + }); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = concat!( + "let mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); while !new_text.is_empty() { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); From 127d03516fb40a16f1cb362d9512653b57fffe67 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 13 Sep 2023 11:57:10 +0200 Subject: [PATCH 4/4] Diff lines one chunk at a time after discovering indentation --- crates/ai/src/codegen.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs index bc13e294fa8ca9a6e49c811a789ce663d4e1e37e..e7da46cdf95f50927f0f9350f45dfb361bdfbd2e 100644 --- a/crates/ai/src/codegen.rs +++ b/crates/ai/src/codegen.rs @@ -225,10 +225,13 @@ impl Codegen { } } - if lines.peek().is_some() { + if line_indent.is_some() { hunks_tx.send(diff.push_new(&new_text)).await?; - hunks_tx.send(diff.push_new("\n")).await?; new_text.clear(); + } + + if lines.peek().is_some() { + hunks_tx.send(diff.push_new("\n")).await?; line_indent = None; first_line = false; }