Extract code generation logic into its own module

Antonio Scandurra created

Change summary

crates/ai/src/ai.rs               |   1 
crates/ai/src/assistant.rs        | 553 ++++++--------------------------
crates/ai/src/codegen.rs          | 468 +++++++++++++++++++++++++++
crates/editor/src/editor.rs       |  10 
crates/editor/src/multi_buffer.rs |  43 +-
5 files changed, 607 insertions(+), 468 deletions(-)

Detailed changes

crates/ai/src/ai.rs 🔗

@@ -1,5 +1,6 @@
 pub mod assistant;
 mod assistant_settings;
+mod codegen;
 mod streaming_diff;
 
 use anyhow::{anyhow, Result};

crates/ai/src/assistant.rs 🔗

@@ -1,9 +1,8 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
-    stream_completion,
-    streaming_diff::{Hunk, StreamingDiff},
-    MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role,
-    SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
+    codegen::{self, Codegen, OpenAICompletionProvider},
+    stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
+    Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -13,10 +12,10 @@ use editor::{
         BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
     },
     scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
-    Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
+    Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset,
 };
 use fs::Fs;
-use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
+use futures::StreamExt;
 use gpui::{
     actions,
     elements::{
@@ -30,17 +29,14 @@ use gpui::{
     ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
     WindowContext,
 };
-use language::{
-    language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, ToOffset as _,
-    TransactionId,
-};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use search::BufferSearchBar;
 use settings::SettingsStore;
 use std::{
     cell::{Cell, RefCell},
     cmp, env,
     fmt::Write,
-    future, iter,
+    iter,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
@@ -266,10 +262,22 @@ impl AssistantPanel {
     }
 
     fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
+        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
+            api_key
+        } else {
+            return;
+        };
+
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
         let selection = editor.read(cx).selections.newest_anchor().clone();
         let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot);
+        let provider = Arc::new(OpenAICompletionProvider::new(
+            api_key,
+            cx.background().clone(),
+        ));
+        let codegen =
+            cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx));
         let assist_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
             InlineAssistKind::Generate
         } else {
@@ -283,6 +291,7 @@ impl AssistantPanel {
                 measurements.clone(),
                 self.include_conversation_in_next_inline_assist,
                 self.inline_prompt_history.clone(),
+                codegen.clone(),
                 cx,
             );
             cx.focus_self();
@@ -323,44 +332,53 @@ impl AssistantPanel {
             PendingInlineAssist {
                 kind: assist_kind,
                 editor: editor.downgrade(),
-                range,
-                highlighted_ranges: Default::default(),
                 inline_assistant: Some((block_id, inline_assistant.clone())),
-                code_generation: Task::ready(None),
-                transaction_id: None,
+                codegen: codegen.clone(),
                 _subscriptions: vec![
                     cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
                     cx.subscribe(editor, {
                         let inline_assistant = inline_assistant.downgrade();
-                        move |this, editor, event, cx| {
+                        move |_, editor, event, cx| {
                             if let Some(inline_assistant) = inline_assistant.upgrade(cx) {
-                                match event {
-                                    editor::Event::SelectionsChanged { local } => {
-                                        if *local && inline_assistant.read(cx).has_focus {
-                                            cx.focus(&editor);
-                                        }
+                                if let editor::Event::SelectionsChanged { local } = event {
+                                    if *local && inline_assistant.read(cx).has_focus {
+                                        cx.focus(&editor);
                                     }
-                                    editor::Event::TransactionUndone {
-                                        transaction_id: tx_id,
-                                    } => {
-                                        if let Some(pending_assist) =
-                                            this.pending_inline_assists.get(&inline_assist_id)
-                                        {
-                                            if pending_assist.transaction_id == Some(*tx_id) {
-                                                // Notice we are supplying `undo: false` here. This
-                                                // is because there's no need to undo the transaction
-                                                // because the user just did so.
-                                                this.close_inline_assist(
-                                                    inline_assist_id,
-                                                    false,
-                                                    cx,
-                                                );
-                                            }
-                                        }
+                                }
+                            }
+                        }
+                    }),
+                    cx.subscribe(&codegen, move |this, codegen, event, cx| match event {
+                        codegen::Event::Undone => {
+                            this.finish_inline_assist(inline_assist_id, false, cx)
+                        }
+                        codegen::Event::Finished => {
+                            let pending_assist = if let Some(pending_assist) =
+                                this.pending_inline_assists.get(&inline_assist_id)
+                            {
+                                pending_assist
+                            } else {
+                                return;
+                            };
+
+                            let error = codegen
+                                .read(cx)
+                                .error()
+                                .map(|error| format!("Inline assistant error: {}", error));
+                            if let Some(error) = error {
+                                if pending_assist.inline_assistant.is_none() {
+                                    if let Some(workspace) = this.workspace.upgrade(cx) {
+                                        workspace.update(cx, |workspace, cx| {
+                                            workspace.show_toast(
+                                                Toast::new(inline_assist_id, error),
+                                                cx,
+                                            );
+                                        })
                                     }
-                                    _ => {}
                                 }
                             }
+
+                            this.finish_inline_assist(inline_assist_id, false, cx);
                         }
                     }),
                 ],
@@ -388,7 +406,7 @@ impl AssistantPanel {
                 self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
             }
             InlineAssistantEvent::Canceled => {
-                self.close_inline_assist(assist_id, true, cx);
+                self.finish_inline_assist(assist_id, true, cx);
             }
             InlineAssistantEvent::Dismissed => {
                 self.hide_inline_assist(assist_id, cx);
@@ -417,7 +435,7 @@ impl AssistantPanel {
                         .get(&editor.downgrade())
                         .and_then(|assist_ids| assist_ids.last().copied())
                     {
-                        panel.close_inline_assist(assist_id, true, cx);
+                        panel.finish_inline_assist(assist_id, true, cx);
                         true
                     } else {
                         false
@@ -432,7 +450,7 @@ impl AssistantPanel {
         cx.propagate_action();
     }
 
-    fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
+    fn finish_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
         self.hide_inline_assist(assist_id, cx);
 
         if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
@@ -450,13 +468,9 @@ impl AssistantPanel {
                 self.update_highlights_for_editor(&editor, cx);
 
                 if undo {
-                    if let Some(transaction_id) = pending_assist.transaction_id {
-                        editor.update(cx, |editor, cx| {
-                            editor.buffer().update(cx, |buffer, cx| {
-                                buffer.undo_transaction(transaction_id, cx)
-                            });
-                        });
-                    }
+                    pending_assist
+                        .codegen
+                        .update(cx, |codegen, cx| codegen.undo(cx));
                 }
             }
         }
@@ -481,12 +495,6 @@ impl AssistantPanel {
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
     ) {
-        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
-            api_key
-        } else {
-            return;
-        };
-
         let conversation = if include_conversation {
             self.active_editor()
                 .map(|editor| editor.read(cx).conversation.clone())
@@ -514,56 +522,9 @@ impl AssistantPanel {
             self.inline_prompt_history.pop_front();
         }
 
-        let range = pending_assist.range.clone();
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
-        let selected_text = snapshot
-            .text_for_range(range.start..range.end)
-            .collect::<Rope>();
-
-        let selection_start = range.start.to_point(&snapshot);
-        let selection_end = range.end.to_point(&snapshot);
-
-        let mut base_indent: Option<language::IndentSize> = None;
-        let mut start_row = selection_start.row;
-        if snapshot.is_line_blank(start_row) {
-            if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
-                start_row = prev_non_blank_row;
-            }
-        }
-        for row in start_row..=selection_end.row {
-            if snapshot.is_line_blank(row) {
-                continue;
-            }
-
-            let line_indent = snapshot.indent_size_for_line(row);
-            if let Some(base_indent) = base_indent.as_mut() {
-                if line_indent.len < base_indent.len {
-                    *base_indent = line_indent;
-                }
-            } else {
-                base_indent = Some(line_indent);
-            }
-        }
-
-        let mut normalized_selected_text = selected_text.clone();
-        if let Some(base_indent) = base_indent {
-            for row in selection_start.row..=selection_end.row {
-                let selection_row = row - selection_start.row;
-                let line_start =
-                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
-                let indent_len = if row == selection_start.row {
-                    base_indent.len.saturating_sub(selection_start.column)
-                } else {
-                    let line_len = normalized_selected_text.line_len(selection_row);
-                    cmp::min(line_len, base_indent.len)
-                };
-                let indent_end = cmp::min(
-                    line_start + indent_len as usize,
-                    normalized_selected_text.len(),
-                );
-                normalized_selected_text.replace(line_start..indent_end, "");
-            }
-        }
+        let range = pending_assist.codegen.read(cx).range();
+        let selected_text = snapshot.text_for_range(range.clone()).collect::<String>();
 
         let language = snapshot.language_at(range.start);
         let language_name = if let Some(language) = language.as_ref() {
@@ -608,7 +569,7 @@ impl AssistantPanel {
                 } else {
                     writeln!(prompt, "```").unwrap();
                 }
-                writeln!(prompt, "{normalized_selected_text}").unwrap();
+                writeln!(prompt, "{selected_text}").unwrap();
                 writeln!(prompt, "```").unwrap();
                 writeln!(prompt).unwrap();
                 writeln!(
@@ -689,209 +650,9 @@ impl AssistantPanel {
             messages,
             stream: true,
         };
-        let response = stream_completion(api_key, cx.background().clone(), request);
-        let editor = editor.downgrade();
-
-        pending_assist.code_generation = cx.spawn(|this, mut cx| {
-            async move {
-                let mut edit_start = range.start.to_offset(&snapshot);
-
-                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
-                let diff = cx.background().spawn(async move {
-                    let chunks = strip_markdown_codeblock(response.await?.filter_map(
-                        |message| async move {
-                            match message {
-                                Ok(mut message) => Some(Ok(message.choices.pop()?.delta.content?)),
-                                Err(error) => Some(Err(error)),
-                            }
-                        },
-                    ));
-                    futures::pin_mut!(chunks);
-                    let mut diff = StreamingDiff::new(selected_text.to_string());
-
-                    let mut indent_len;
-                    let indent_text;
-                    if let Some(base_indent) = base_indent {
-                        indent_len = base_indent.len;
-                        indent_text = match base_indent.kind {
-                            language::IndentKind::Space => " ",
-                            language::IndentKind::Tab => "\t",
-                        };
-                    } else {
-                        indent_len = 0;
-                        indent_text = "";
-                    };
-
-                    let mut first_line_len = 0;
-                    let mut first_line_non_whitespace_char_ix = None;
-                    let mut first_line = true;
-                    let mut new_text = String::new();
-
-                    while let Some(chunk) = chunks.next().await {
-                        let chunk = chunk?;
-
-                        let mut lines = chunk.split('\n');
-                        if let Some(mut line) = lines.next() {
-                            if first_line {
-                                if first_line_non_whitespace_char_ix.is_none() {
-                                    if let Some(mut char_ix) =
-                                        line.find(|ch: char| !ch.is_whitespace())
-                                    {
-                                        line = &line[char_ix..];
-                                        char_ix += first_line_len;
-                                        first_line_non_whitespace_char_ix = Some(char_ix);
-                                        let first_line_indent = char_ix
-                                            .saturating_sub(selection_start.column as usize)
-                                            as usize;
-                                        new_text.push_str(&indent_text.repeat(first_line_indent));
-                                        indent_len = indent_len.saturating_sub(char_ix as u32);
-                                    }
-                                }
-                                first_line_len += line.len();
-                            }
-
-                            if first_line_non_whitespace_char_ix.is_some() {
-                                new_text.push_str(line);
-                            }
-                        }
-
-                        for line in lines {
-                            first_line = false;
-                            new_text.push('\n');
-                            if !line.is_empty() {
-                                new_text.push_str(&indent_text.repeat(indent_len as usize));
-                            }
-                            new_text.push_str(line);
-                        }
-
-                        let hunks = diff.push_new(&new_text);
-                        hunks_tx.send(hunks).await?;
-                        new_text.clear();
-                    }
-                    hunks_tx.send(diff.finish()).await?;
-
-                    anyhow::Ok(())
-                });
-
-                while let Some(hunks) = hunks_rx.next().await {
-                    let editor = if let Some(editor) = editor.upgrade(&cx) {
-                        editor
-                    } else {
-                        break;
-                    };
-
-                    let this = if let Some(this) = this.upgrade(&cx) {
-                        this
-                    } else {
-                        break;
-                    };
-
-                    this.update(&mut cx, |this, cx| {
-                        let pending_assist = if let Some(pending_assist) =
-                            this.pending_inline_assists.get_mut(&inline_assist_id)
-                        {
-                            pending_assist
-                        } else {
-                            return;
-                        };
-
-                        pending_assist.highlighted_ranges.clear();
-                        editor.update(cx, |editor, cx| {
-                            let transaction = editor.buffer().update(cx, |buffer, cx| {
-                                // Avoid grouping assistant edits with user edits.
-                                buffer.finalize_last_transaction(cx);
-
-                                buffer.start_transaction(cx);
-                                buffer.edit(
-                                    hunks.into_iter().filter_map(|hunk| match hunk {
-                                        Hunk::Insert { text } => {
-                                            let edit_start = snapshot.anchor_after(edit_start);
-                                            Some((edit_start..edit_start, text))
-                                        }
-                                        Hunk::Remove { len } => {
-                                            let edit_end = edit_start + len;
-                                            let edit_range = snapshot.anchor_after(edit_start)
-                                                ..snapshot.anchor_before(edit_end);
-                                            edit_start = edit_end;
-                                            Some((edit_range, String::new()))
-                                        }
-                                        Hunk::Keep { len } => {
-                                            let edit_end = edit_start + len;
-                                            let edit_range = snapshot.anchor_after(edit_start)
-                                                ..snapshot.anchor_before(edit_end);
-                                            edit_start += len;
-                                            pending_assist.highlighted_ranges.push(edit_range);
-                                            None
-                                        }
-                                    }),
-                                    None,
-                                    cx,
-                                );
-
-                                buffer.end_transaction(cx)
-                            });
-
-                            if let Some(transaction) = transaction {
-                                if let Some(first_transaction) = pending_assist.transaction_id {
-                                    // Group all assistant edits into the first transaction.
-                                    editor.buffer().update(cx, |buffer, cx| {
-                                        buffer.merge_transactions(
-                                            transaction,
-                                            first_transaction,
-                                            cx,
-                                        )
-                                    });
-                                } else {
-                                    pending_assist.transaction_id = Some(transaction);
-                                    editor.buffer().update(cx, |buffer, cx| {
-                                        buffer.finalize_last_transaction(cx)
-                                    });
-                                }
-                            }
-                        });
-
-                        this.update_highlights_for_editor(&editor, cx);
-                    });
-                }
-
-                if let Err(error) = diff.await {
-                    this.update(&mut cx, |this, cx| {
-                        let pending_assist = if let Some(pending_assist) =
-                            this.pending_inline_assists.get_mut(&inline_assist_id)
-                        {
-                            pending_assist
-                        } else {
-                            return;
-                        };
-
-                        if let Some((_, inline_assistant)) =
-                            pending_assist.inline_assistant.as_ref()
-                        {
-                            inline_assistant.update(cx, |inline_assistant, cx| {
-                                inline_assistant.set_error(error, cx);
-                            });
-                        } else if let Some(workspace) = this.workspace.upgrade(cx) {
-                            workspace.update(cx, |workspace, cx| {
-                                workspace.show_toast(
-                                    Toast::new(
-                                        inline_assist_id,
-                                        format!("Inline assistant error: {}", error),
-                                    ),
-                                    cx,
-                                );
-                            })
-                        }
-                    })?;
-                } else {
-                    let _ = this.update(&mut cx, |this, cx| {
-                        this.close_inline_assist(inline_assist_id, false, cx)
-                    });
-                }
-
-                anyhow::Ok(())
-            }
-            .log_err()
-        });
+        pending_assist
+            .codegen
+            .update(cx, |codegen, cx| codegen.start(request, cx));
     }
 
     fn update_highlights_for_editor(
@@ -909,8 +670,9 @@ impl AssistantPanel {
 
         for inline_assist_id in inline_assist_ids {
             if let Some(pending_assist) = self.pending_inline_assists.get(inline_assist_id) {
-                background_ranges.push(pending_assist.range.clone());
-                foreground_ranges.extend(pending_assist.highlighted_ranges.iter().cloned());
+                let codegen = pending_assist.codegen.read(cx);
+                background_ranges.push(codegen.range());
+                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
             }
         }
 
@@ -2900,11 +2662,11 @@ struct InlineAssistant {
     has_focus: bool,
     include_conversation: bool,
     measurements: Rc<Cell<BlockMeasurements>>,
-    error: Option<anyhow::Error>,
     prompt_history: VecDeque<String>,
     prompt_history_ix: Option<usize>,
     pending_prompt: String,
-    _subscription: Subscription,
+    codegen: ModelHandle<Codegen>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl Entity for InlineAssistant {
@@ -2933,7 +2695,7 @@ impl View for InlineAssistant {
                             .element()
                             .aligned(),
                     )
-                    .with_children(if let Some(error) = self.error.as_ref() {
+                    .with_children(if let Some(error) = self.codegen.read(cx).error() {
                         Some(
                             Svg::new("icons/circle_x_mark_12.svg")
                                 .with_color(theme.assistant.error_icon.color)
@@ -3011,6 +2773,7 @@ impl InlineAssistant {
         measurements: Rc<Cell<BlockMeasurements>>,
         include_conversation: bool,
         prompt_history: VecDeque<String>,
+        codegen: ModelHandle<Codegen>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let prompt_editor = cx.add_view(|cx| {
@@ -3025,7 +2788,10 @@ impl InlineAssistant {
             editor.set_placeholder_text(placeholder, cx);
             editor
         });
-        let subscription = cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events);
+        let subscriptions = vec![
+            cx.observe(&codegen, Self::handle_codegen_changed),
+            cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
+        ];
         Self {
             id,
             prompt_editor,
@@ -3033,11 +2799,11 @@ impl InlineAssistant {
             has_focus: false,
             include_conversation,
             measurements,
-            error: None,
             prompt_history,
             prompt_history_ix: None,
             pending_prompt: String::new(),
-            _subscription: subscription,
+            codegen,
+            _subscriptions: subscriptions,
         }
     }
 
@@ -3053,6 +2819,31 @@ impl InlineAssistant {
         }
     }
 
+    fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
+        let is_read_only = !self.codegen.read(cx).idle();
+        self.prompt_editor.update(cx, |editor, cx| {
+            let was_read_only = editor.read_only();
+            if was_read_only != is_read_only {
+                if is_read_only {
+                    editor.set_read_only(true);
+                    editor.set_field_editor_style(
+                        Some(Arc::new(|theme| {
+                            theme.assistant.inline.disabled_editor.clone()
+                        })),
+                        cx,
+                    );
+                } else {
+                    editor.set_read_only(false);
+                    editor.set_field_editor_style(
+                        Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
+                        cx,
+                    );
+                }
+            }
+        });
+        cx.notify();
+    }
+
     fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
         cx.emit(InlineAssistantEvent::Canceled);
     }
@@ -3076,7 +2867,6 @@ impl InlineAssistant {
                 include_conversation: self.include_conversation,
             });
             self.confirmed = true;
-            self.error = None;
             cx.notify();
         }
     }
@@ -3093,19 +2883,6 @@ impl InlineAssistant {
         cx.notify();
     }
 
-    fn set_error(&mut self, error: anyhow::Error, cx: &mut ViewContext<Self>) {
-        self.error = Some(error);
-        self.confirmed = false;
-        self.prompt_editor.update(cx, |editor, cx| {
-            editor.set_read_only(false);
-            editor.set_field_editor_style(
-                Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
-                cx,
-            );
-        });
-        cx.notify();
-    }
-
     fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
         if let Some(ix) = self.prompt_history_ix {
             if ix > 0 {
@@ -3154,11 +2931,8 @@ struct BlockMeasurements {
 struct PendingInlineAssist {
     kind: InlineAssistKind,
     editor: WeakViewHandle<Editor>,
-    range: Range<Anchor>,
-    highlighted_ranges: Vec<Range<Anchor>>,
     inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
-    code_generation: Task<Option<()>>,
-    transaction_id: Option<TransactionId>,
+    codegen: ModelHandle<Codegen>,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -3184,65 +2958,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
     }
 }
 
-fn strip_markdown_codeblock(
-    stream: impl Stream<Item = Result<String>>,
-) -> impl Stream<Item = Result<String>> {
-    let mut first_line = true;
-    let mut buffer = String::new();
-    let mut starts_with_fenced_code_block = false;
-    stream.filter_map(move |chunk| {
-        let chunk = match chunk {
-            Ok(chunk) => chunk,
-            Err(err) => return future::ready(Some(Err(err))),
-        };
-        buffer.push_str(&chunk);
-
-        if first_line {
-            if buffer == "" || buffer == "`" || buffer == "``" {
-                return future::ready(None);
-            } else if buffer.starts_with("```") {
-                starts_with_fenced_code_block = true;
-                if let Some(newline_ix) = buffer.find('\n') {
-                    buffer.replace_range(..newline_ix + 1, "");
-                    first_line = false;
-                } else {
-                    return future::ready(None);
-                }
-            }
-        }
-
-        let text = if starts_with_fenced_code_block {
-            buffer
-                .strip_suffix("\n```\n")
-                .or_else(|| buffer.strip_suffix("\n```"))
-                .or_else(|| buffer.strip_suffix("\n``"))
-                .or_else(|| buffer.strip_suffix("\n`"))
-                .or_else(|| buffer.strip_suffix('\n'))
-                .unwrap_or(&buffer)
-        } else {
-            &buffer
-        };
-
-        if text.contains('\n') {
-            first_line = false;
-        }
-
-        let remainder = buffer.split_off(text.len());
-        let result = if buffer.is_empty() {
-            None
-        } else {
-            Some(Ok(buffer.clone()))
-        };
-        buffer = remainder;
-        future::ready(result)
-    })
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::MessageId;
-    use futures::stream;
     use gpui::AppContext;
 
     #[gpui::test]
@@ -3611,62 +3330,6 @@ mod tests {
         );
     }
 
-    #[gpui::test]
-    async fn test_strip_markdown_codeblock() {
-        assert_eq!(
-            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "Lorem ipsum dolor"
-        );
-        assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "Lorem ipsum dolor"
-        );
-        assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "Lorem ipsum dolor"
-        );
-        assert_eq!(
-            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "Lorem ipsum dolor"
-        );
-        assert_eq!(
-            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "```js\nLorem ipsum dolor\n```"
-        );
-        assert_eq!(
-            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
-                .map(|chunk| chunk.unwrap())
-                .collect::<String>()
-                .await,
-            "``\nLorem ipsum dolor\n```"
-        );
-
-        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
-            stream::iter(
-                text.chars()
-                    .collect::<Vec<_>>()
-                    .chunks(size)
-                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
-                    .collect::<Vec<_>>(),
-            )
-        }
-    }
-
     fn messages(
         conversation: &ModelHandle<Conversation>,
         cx: &AppContext,

crates/ai/src/codegen.rs 🔗

@@ -0,0 +1,468 @@
+use crate::{
+    stream_completion,
+    streaming_diff::{Hunk, StreamingDiff},
+    OpenAIRequest,
+};
+use anyhow::Result;
+use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
+use futures::{
+    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
+};
+use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
+use language::{IndentSize, Point, Rope, TransactionId};
+use std::{cmp, future, ops::Range, sync::Arc};
+
+pub trait CompletionProvider {
+    fn complete(
+        &self,
+        prompt: OpenAIRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+}
+
+pub struct OpenAICompletionProvider {
+    api_key: String,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
+        Self { api_key, executor }
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn complete(
+        &self,
+        prompt: OpenAIRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+}
+
+pub enum Event {
+    Finished,
+    Undone,
+}
+
+pub struct Codegen {
+    provider: Arc<dyn CompletionProvider>,
+    buffer: ModelHandle<MultiBuffer>,
+    range: Range<Anchor>,
+    last_equal_ranges: Vec<Range<Anchor>>,
+    transaction_id: Option<TransactionId>,
+    error: Option<anyhow::Error>,
+    generation: Task<()>,
+    idle: bool,
+    _subscription: gpui::Subscription,
+}
+
+impl Entity for Codegen {
+    type Event = Event;
+}
+
+impl Codegen {
+    pub fn new(
+        buffer: ModelHandle<MultiBuffer>,
+        range: Range<Anchor>,
+        provider: Arc<dyn CompletionProvider>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        Self {
+            provider,
+            buffer: buffer.clone(),
+            range,
+            last_equal_ranges: Default::default(),
+            transaction_id: Default::default(),
+            error: Default::default(),
+            idle: true,
+            generation: Task::ready(()),
+            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+        }
+    }
+
+    fn handle_buffer_event(
+        &mut self,
+        _buffer: ModelHandle<MultiBuffer>,
+        event: &multi_buffer::Event,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
+            if self.transaction_id == Some(*transaction_id) {
+                self.transaction_id = None;
+                self.generation = Task::ready(());
+                cx.emit(Event::Undone);
+            }
+        }
+    }
+
+    pub fn range(&self) -> Range<Anchor> {
+        self.range.clone()
+    }
+
+    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
+        &self.last_equal_ranges
+    }
+
+    pub fn idle(&self) -> bool {
+        self.idle
+    }
+
+    pub fn error(&self) -> Option<&anyhow::Error> {
+        self.error.as_ref()
+    }
+
+    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+        let range = self.range.clone();
+        let snapshot = self.buffer.read(cx).snapshot(cx);
+        let selected_text = snapshot
+            .text_for_range(range.start..range.end)
+            .collect::<Rope>();
+
+        let selection_start = range.start.to_point(&snapshot);
+        let selection_end = range.end.to_point(&snapshot);
+
+        let mut base_indent: Option<IndentSize> = None;
+        let mut start_row = selection_start.row;
+        if snapshot.is_line_blank(start_row) {
+            if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
+                start_row = prev_non_blank_row;
+            }
+        }
+        for row in start_row..=selection_end.row {
+            if snapshot.is_line_blank(row) {
+                continue;
+            }
+
+            let line_indent = snapshot.indent_size_for_line(row);
+            if let Some(base_indent) = base_indent.as_mut() {
+                if line_indent.len < base_indent.len {
+                    *base_indent = line_indent;
+                }
+            } else {
+                base_indent = Some(line_indent);
+            }
+        }
+
+        let mut normalized_selected_text = selected_text.clone();
+        if let Some(base_indent) = base_indent {
+            for row in selection_start.row..=selection_end.row {
+                let selection_row = row - selection_start.row;
+                let line_start =
+                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
+                let indent_len = if row == selection_start.row {
+                    base_indent.len.saturating_sub(selection_start.column)
+                } else {
+                    let line_len = normalized_selected_text.line_len(selection_row);
+                    cmp::min(line_len, base_indent.len)
+                };
+                let indent_end = cmp::min(
+                    line_start + indent_len as usize,
+                    normalized_selected_text.len(),
+                );
+                normalized_selected_text.replace(line_start..indent_end, "");
+            }
+        }
+
+        let response = self.provider.complete(prompt);
+        self.generation = cx.spawn_weak(|this, mut cx| {
+            async move {
+                let generate = async {
+                    let mut edit_start = range.start.to_offset(&snapshot);
+
+                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+                    let diff = cx.background().spawn(async move {
+                        let chunks = strip_markdown_codeblock(response.await?);
+                        futures::pin_mut!(chunks);
+                        let mut diff = StreamingDiff::new(selected_text.to_string());
+
+                        let mut indent_len;
+                        let indent_text;
+                        if let Some(base_indent) = base_indent {
+                            indent_len = base_indent.len;
+                            indent_text = match base_indent.kind {
+                                language::IndentKind::Space => " ",
+                                language::IndentKind::Tab => "\t",
+                            };
+                        } else {
+                            indent_len = 0;
+                            indent_text = "";
+                        };
+
+                        let mut first_line_len = 0;
+                        let mut first_line_non_whitespace_char_ix = None;
+                        let mut first_line = true;
+                        let mut new_text = String::new();
+
+                        while let Some(chunk) = chunks.next().await {
+                            let chunk = chunk?;
+
+                            let mut lines = chunk.split('\n');
+                            if let Some(mut line) = lines.next() {
+                                if first_line {
+                                    if first_line_non_whitespace_char_ix.is_none() {
+                                        if let Some(mut char_ix) =
+                                            line.find(|ch: char| !ch.is_whitespace())
+                                        {
+                                            line = &line[char_ix..];
+                                            char_ix += first_line_len;
+                                            first_line_non_whitespace_char_ix = Some(char_ix);
+                                            let first_line_indent = char_ix
+                                                .saturating_sub(selection_start.column as usize)
+                                                as usize;
+                                            new_text
+                                                .push_str(&indent_text.repeat(first_line_indent));
+                                            indent_len = indent_len.saturating_sub(char_ix as u32);
+                                        }
+                                    }
+                                    first_line_len += line.len();
+                                }
+
+                                if first_line_non_whitespace_char_ix.is_some() {
+                                    new_text.push_str(line);
+                                }
+                            }
+
+                            for line in lines {
+                                first_line = false;
+                                new_text.push('\n');
+                                if !line.is_empty() {
+                                    new_text.push_str(&indent_text.repeat(indent_len as usize));
+                                }
+                                new_text.push_str(line);
+                            }
+
+                            let hunks = diff.push_new(&new_text);
+                            hunks_tx.send(hunks).await?;
+                            new_text.clear();
+                        }
+                        hunks_tx.send(diff.finish()).await?;
+
+                        anyhow::Ok(())
+                    });
+
+                    while let Some(hunks) = hunks_rx.next().await {
+                        let this = if let Some(this) = this.upgrade(&cx) {
+                            this
+                        } else {
+                            break;
+                        };
+
+                        this.update(&mut cx, |this, cx| {
+                            this.last_equal_ranges.clear();
+
+                            let transaction = this.buffer.update(cx, |buffer, cx| {
+                                // Avoid grouping assistant edits with user edits.
+                                buffer.finalize_last_transaction(cx);
+
+                                buffer.start_transaction(cx);
+                                buffer.edit(
+                                    hunks.into_iter().filter_map(|hunk| match hunk {
+                                        Hunk::Insert { text } => {
+                                            let edit_start = snapshot.anchor_after(edit_start);
+                                            Some((edit_start..edit_start, text))
+                                        }
+                                        Hunk::Remove { len } => {
+                                            let edit_end = edit_start + len;
+                                            let edit_range = snapshot.anchor_after(edit_start)
+                                                ..snapshot.anchor_before(edit_end);
+                                            edit_start = edit_end;
+                                            Some((edit_range, String::new()))
+                                        }
+                                        Hunk::Keep { len } => {
+                                            let edit_end = edit_start + len;
+                                            let edit_range = snapshot.anchor_after(edit_start)
+                                                ..snapshot.anchor_before(edit_end);
+                                            edit_start += len;
+                                            this.last_equal_ranges.push(edit_range);
+                                            None
+                                        }
+                                    }),
+                                    None,
+                                    cx,
+                                );
+
+                                buffer.end_transaction(cx)
+                            });
+
+                            if let Some(transaction) = transaction {
+                                if let Some(first_transaction) = this.transaction_id {
+                                    // Group all assistant edits into the first transaction.
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        buffer.merge_transactions(
+                                            transaction,
+                                            first_transaction,
+                                            cx,
+                                        )
+                                    });
+                                } else {
+                                    this.transaction_id = Some(transaction);
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        buffer.finalize_last_transaction(cx)
+                                    });
+                                }
+                            }
+
+                            cx.notify();
+                        });
+                    }
+
+                    diff.await?;
+                    anyhow::Ok(())
+                };
+
+                let result = generate.await;
+                if let Some(this) = this.upgrade(&cx) {
+                    this.update(&mut cx, |this, cx| {
+                        this.last_equal_ranges.clear();
+                        this.idle = true;
+                        if let Err(error) = result {
+                            this.error = Some(error);
+                        }
+                        cx.emit(Event::Finished);
+                        cx.notify();
+                    });
+                }
+            }
+        });
+        self.error.take();
+        self.idle = false;
+        cx.notify();
+    }
+
+    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+        if let Some(transaction_id) = self.transaction_id {
+            self.buffer
+                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
+        }
+    }
+}
+
+fn strip_markdown_codeblock(
+    stream: impl Stream<Item = Result<String>>,
+) -> impl Stream<Item = Result<String>> {
+    let mut first_line = true;
+    let mut buffer = String::new();
+    let mut starts_with_fenced_code_block = false;
+    stream.filter_map(move |chunk| {
+        let chunk = match chunk {
+            Ok(chunk) => chunk,
+            Err(err) => return future::ready(Some(Err(err))),
+        };
+        buffer.push_str(&chunk);
+
+        if first_line {
+            if buffer == "" || buffer == "`" || buffer == "``" {
+                return future::ready(None);
+            } else if buffer.starts_with("```") {
+                starts_with_fenced_code_block = true;
+                if let Some(newline_ix) = buffer.find('\n') {
+                    buffer.replace_range(..newline_ix + 1, "");
+                    first_line = false;
+                } else {
+                    return future::ready(None);
+                }
+            }
+        }
+
+        let text = if starts_with_fenced_code_block {
+            buffer
+                .strip_suffix("\n```\n")
+                .or_else(|| buffer.strip_suffix("\n```"))
+                .or_else(|| buffer.strip_suffix("\n``"))
+                .or_else(|| buffer.strip_suffix("\n`"))
+                .or_else(|| buffer.strip_suffix('\n'))
+                .unwrap_or(&buffer)
+        } else {
+            &buffer
+        };
+
+        if text.contains('\n') {
+            first_line = false;
+        }
+
+        let remainder = buffer.split_off(text.len());
+        let result = if buffer.is_empty() {
+            None
+        } else {
+            Some(Ok(buffer.clone()))
+        };
+        buffer = remainder;
+        future::ready(result)
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::stream;
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_strip_markdown_codeblock() {
+        assert_eq!(
+            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!(
+            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!(
+            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!(
+            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "Lorem ipsum dolor"
+        );
+        assert_eq!(
+            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "```js\nLorem ipsum dolor\n```"
+        );
+        assert_eq!(
+            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
+                .map(|chunk| chunk.unwrap())
+                .collect::<String>()
+                .await,
+            "``\nLorem ipsum dolor\n```"
+        );
+
+        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
+            stream::iter(
+                text.chars()
+                    .collect::<Vec<_>>()
+                    .chunks(size)
+                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
+                    .collect::<Vec<_>>(),
+            )
+        }
+    }
+}

crates/editor/src/editor.rs 🔗

@@ -1734,6 +1734,10 @@ impl Editor {
         }
     }
 
+    pub fn read_only(&self) -> bool {
+        self.read_only
+    }
+
     pub fn set_read_only(&mut self, read_only: bool) {
         self.read_only = read_only;
     }
@@ -5103,9 +5107,6 @@ impl Editor {
             self.unmark_text(cx);
             self.refresh_copilot_suggestions(true, cx);
             cx.emit(Event::Edited);
-            cx.emit(Event::TransactionUndone {
-                transaction_id: tx_id,
-            });
         }
     }
 
@@ -8548,9 +8549,6 @@ pub enum Event {
         local: bool,
         autoscroll: bool,
     },
-    TransactionUndone {
-        transaction_id: TransactionId,
-    },
     Closed,
 }
 

crates/editor/src/multi_buffer.rs 🔗

@@ -70,6 +70,9 @@ pub enum Event {
     Edited {
         sigleton_buffer_edited: bool,
     },
+    TransactionUndone {
+        transaction_id: TransactionId,
+    },
     Reloaded,
     DiffBaseChanged,
     LanguageChanged,
@@ -771,30 +774,36 @@ impl MultiBuffer {
     }
 
     pub fn undo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {
+        let mut transaction_id = None;
         if let Some(buffer) = self.as_singleton() {
-            return buffer.update(cx, |buffer, cx| buffer.undo(cx));
-        }
+            transaction_id = buffer.update(cx, |buffer, cx| buffer.undo(cx));
+        } else {
+            while let Some(transaction) = self.history.pop_undo() {
+                let mut undone = false;
+                for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
+                    if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
+                        undone |= buffer.update(cx, |buffer, cx| {
+                            let undo_to = *buffer_transaction_id;
+                            if let Some(entry) = buffer.peek_undo_stack() {
+                                *buffer_transaction_id = entry.transaction_id();
+                            }
+                            buffer.undo_to_transaction(undo_to, cx)
+                        });
+                    }
+                }
 
-        while let Some(transaction) = self.history.pop_undo() {
-            let mut undone = false;
-            for (buffer_id, buffer_transaction_id) in &mut transaction.buffer_transactions {
-                if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(buffer_id) {
-                    undone |= buffer.update(cx, |buffer, cx| {
-                        let undo_to = *buffer_transaction_id;
-                        if let Some(entry) = buffer.peek_undo_stack() {
-                            *buffer_transaction_id = entry.transaction_id();
-                        }
-                        buffer.undo_to_transaction(undo_to, cx)
-                    });
+                if undone {
+                    transaction_id = Some(transaction.id);
+                    break;
                 }
             }
+        }
 
-            if undone {
-                return Some(transaction.id);
-            }
+        if let Some(transaction_id) = transaction_id {
+            cx.emit(Event::TransactionUndone { transaction_id });
         }
 
-        None
+        transaction_id
     }
 
     pub fn redo(&mut self, cx: &mut ModelContext<Self>) -> Option<TransactionId> {