Rework inline assistant

Antonio Scandurra created

Change summary

Cargo.lock                             |   1 
assets/keymaps/default.json            |   3 
crates/ai/Cargo.toml                   |   1 
crates/ai/src/ai.rs                    |   3 
crates/ai/src/assistant.rs             | 574 ++++++++++++++++++++++++++-
crates/ai/src/refactoring_assistant.rs | 252 ------------
crates/ai/src/refactoring_modal.rs     | 137 ------
crates/ai/src/streaming_diff.rs        |  12 
crates/editor/src/editor.rs            |   4 
crates/editor/src/multi_buffer.rs      |  16 
crates/language/src/buffer.rs          |  16 
crates/text/src/text.rs                |   9 
crates/theme/src/theme.rs              |   5 
styles/src/style_tree/assistant.ts     |   5 
14 files changed, 600 insertions(+), 438 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -122,7 +122,6 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
- "similar",
  "smol",
  "theme",
  "tiktoken-rs 0.4.5",

assets/keymaps/default.json 🔗

@@ -528,7 +528,8 @@
     "bindings": {
       "alt-enter": "editor::OpenExcerpts",
       "cmd-f8": "editor::GoToHunk",
-      "cmd-shift-f8": "editor::GoToPrevHunk"
+      "cmd-shift-f8": "editor::GoToPrevHunk",
+      "ctrl-enter": "assistant::InlineAssist"
     }
   },
   {

crates/ai/Cargo.toml 🔗

@@ -31,7 +31,6 @@ regex.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
-similar = "1.3"
 smol.workspace = true
 tiktoken-rs = "0.4"
 

crates/ai/src/ai.rs 🔗

@@ -1,7 +1,5 @@
 pub mod assistant;
 mod assistant_settings;
-mod refactoring_assistant;
-mod refactoring_modal;
 mod streaming_diff;
 
 use anyhow::{anyhow, Result};
@@ -196,7 +194,6 @@ struct OpenAIChoice {
 
 pub fn init(cx: &mut AppContext) {
     assistant::init(cx);
-    refactoring_modal::init(cx);
 }
 
 pub async fn stream_completion(

crates/ai/src/assistant.rs 🔗

@@ -1,18 +1,22 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings},
-    stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
-    Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
+    stream_completion,
+    streaming_diff::{Hunk, StreamingDiff},
+    MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role,
+    SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use collections::{HashMap, HashSet};
 use editor::{
-    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
+    display_map::{
+        BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
+    },
     scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
-    Anchor, Editor, ToOffset,
+    Anchor, Editor, ToOffset, ToPoint,
 };
 use fs::Fs;
-use futures::StreamExt;
+use futures::{channel::mpsc, SinkExt, StreamExt};
 use gpui::{
     actions,
     elements::*,
@@ -21,7 +25,10 @@ use gpui::{
     Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
-use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
+use language::{
+    language_settings::SoftWrap, Buffer, LanguageRegistry, Point, Rope, Selection, ToOffset as _,
+    TransactionId,
+};
 use search::BufferSearchBar;
 use settings::SettingsStore;
 use std::{
@@ -53,6 +60,7 @@ actions!(
         QuoteSelection,
         ToggleFocus,
         ResetKey,
+        InlineAssist
     ]
 );
 
@@ -84,6 +92,9 @@ pub fn init(cx: &mut AppContext) {
             workspace.toggle_panel_focus::<AssistantPanel>(cx);
         },
     );
+    cx.add_action(AssistantPanel::inline_assist);
+    cx.add_action(InlineAssistant::confirm);
+    cx.add_action(InlineAssistant::cancel);
 }
 
 #[derive(Debug)]
@@ -113,6 +124,9 @@ pub struct AssistantPanel {
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     subscriptions: Vec<Subscription>,
+    next_inline_assist_id: usize,
+    pending_inline_assists: HashMap<usize, PendingInlineAssist>,
+    pending_inline_assist_ids_by_editor: HashMap<WeakViewHandle<Editor>, Vec<usize>>,
     _watch_saved_conversations: Task<Result<()>>,
 }
 
@@ -176,6 +190,9 @@ impl AssistantPanel {
                         width: None,
                         height: None,
                         subscriptions: Default::default(),
+                        next_inline_assist_id: 0,
+                        pending_inline_assists: Default::default(),
+                        pending_inline_assist_ids_by_editor: Default::default(),
                         _watch_saved_conversations,
                     };
 
@@ -196,6 +213,425 @@ impl AssistantPanel {
         })
     }
 
+    fn inline_assist(workspace: &mut Workspace, _: &InlineAssist, cx: &mut ViewContext<Workspace>) {
+        let assistant = if let Some(assistant) = workspace.panel::<AssistantPanel>(cx) {
+            if assistant
+                .update(cx, |assistant, cx| assistant.load_api_key(cx))
+                .is_some()
+            {
+                assistant
+            } else {
+                workspace.focus_panel::<AssistantPanel>(cx);
+                return;
+            }
+        } else {
+            return;
+        };
+
+        let active_editor = if let Some(active_editor) = workspace
+            .active_item(cx)
+            .and_then(|item| item.act_as::<Editor>(cx))
+        {
+            active_editor
+        } else {
+            return;
+        };
+
+        assistant.update(cx, |assistant, cx| {
+            assistant.new_inline_assist(&active_editor, cx)
+        });
+    }
+
+    fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
+        let id = post_inc(&mut self.next_inline_assist_id);
+        let (block_id, inline_assistant, selection) = editor.update(cx, |editor, cx| {
+            let selection = editor.selections.newest_anchor().clone();
+            let prompt_editor = cx.add_view(|cx| {
+                Editor::single_line(
+                    Some(Arc::new(|theme| theme.assistant.inline.editor.clone())),
+                    cx,
+                )
+            });
+            let assist_kind = if editor.selections.newest::<usize>(cx).is_empty() {
+                InlineAssistKind::Insert
+            } else {
+                InlineAssistKind::Edit
+            };
+            let assistant = cx.add_view(|_| InlineAssistant {
+                id,
+                prompt_editor,
+                confirmed: false,
+                has_focus: false,
+                assist_kind,
+            });
+            cx.focus(&assistant);
+
+            let block_id = editor.insert_blocks(
+                [BlockProperties {
+                    style: BlockStyle::Flex,
+                    position: selection.head(),
+                    height: 2,
+                    render: Arc::new({
+                        let assistant = assistant.clone();
+                        move |cx: &mut BlockContext| {
+                            ChildView::new(&assistant, cx)
+                                .contained()
+                                .with_padding_left(match assist_kind {
+                                    InlineAssistKind::Edit => cx.gutter_width,
+                                    InlineAssistKind::Insert => cx.anchor_x,
+                                })
+                                .into_any()
+                        }
+                    }),
+                    disposition: if selection.reversed {
+                        BlockDisposition::Above
+                    } else {
+                        BlockDisposition::Below
+                    },
+                }],
+                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
+                cx,
+            )[0];
+            editor.highlight_background::<Self>(
+                vec![selection.start..selection.end],
+                |theme| theme.assistant.inline.pending_edit_background,
+                cx,
+            );
+
+            (block_id, assistant, selection)
+        });
+
+        self.pending_inline_assists.insert(
+            id,
+            PendingInlineAssist {
+                editor: editor.downgrade(),
+                selection,
+                inline_assistant_block_id: Some(block_id),
+                code_generation: Task::ready(None),
+                transaction_id: None,
+                _subscriptions: vec![
+                    cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
+                    cx.subscribe(editor, {
+                        let inline_assistant = inline_assistant.downgrade();
+                        move |_, editor, event, cx| {
+                            if let Some(inline_assistant) = inline_assistant.upgrade(cx) {
+                                if let editor::Event::SelectionsChanged { local } = event {
+                                    if *local && inline_assistant.read(cx).has_focus {
+                                        cx.focus(&editor);
+                                    }
+                                }
+                            }
+                        }
+                    }),
+                ],
+            },
+        );
+        self.pending_inline_assist_ids_by_editor
+            .entry(editor.downgrade())
+            .or_default()
+            .push(id);
+    }
+
+    fn handle_inline_assistant_event(
+        &mut self,
+        inline_assistant: ViewHandle<InlineAssistant>,
+        event: &InlineAssistantEvent,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let assist_id = inline_assistant.read(cx).id;
+        match event {
+            InlineAssistantEvent::Confirmed { prompt } => {
+                self.generate_code(assist_id, prompt, cx);
+            }
+            InlineAssistantEvent::Canceled => {
+                self.complete_inline_assist(assist_id, true, cx);
+            }
+            InlineAssistantEvent::Dismissed => {
+                self.dismiss_inline_assist(assist_id, cx);
+            }
+        }
+    }
+
+    fn complete_inline_assist(
+        &mut self,
+        assist_id: usize,
+        cancel: bool,
+        cx: &mut ViewContext<Self>,
+    ) {
+        self.dismiss_inline_assist(assist_id, cx);
+
+        if let Some(pending_assist) = self.pending_inline_assists.remove(&assist_id) {
+            self.pending_inline_assist_ids_by_editor
+                .remove(&pending_assist.editor);
+
+            if let Some(editor) = pending_assist.editor.upgrade(cx) {
+                editor.update(cx, |editor, cx| {
+                    editor.clear_background_highlights::<Self>(cx);
+                    editor.clear_text_highlights::<Self>(cx);
+                });
+
+                if cancel {
+                    if let Some(transaction_id) = pending_assist.transaction_id {
+                        editor.update(cx, |editor, cx| {
+                            editor.buffer().update(cx, |buffer, cx| {
+                                buffer.undo_and_forget(transaction_id, cx)
+                            });
+                        });
+                    }
+                }
+            }
+        }
+    }
+
+    fn dismiss_inline_assist(&mut self, assist_id: usize, cx: &mut ViewContext<Self>) {
+        if let Some(pending_assist) = self.pending_inline_assists.get_mut(&assist_id) {
+            if let Some(editor) = pending_assist.editor.upgrade(cx) {
+                if let Some(block_id) = pending_assist.inline_assistant_block_id.take() {
+                    editor.update(cx, |editor, cx| {
+                        editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
+                    });
+                }
+            }
+        }
+    }
+
+    pub fn generate_code(
+        &mut self,
+        inline_assist_id: usize,
+        user_prompt: &str,
+        cx: &mut ViewContext<Self>,
+    ) {
+        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
+            api_key
+        } else {
+            return;
+        };
+
+        let pending_assist =
+            if let Some(pending_assist) = self.pending_inline_assists.get_mut(&inline_assist_id) {
+                pending_assist
+            } else {
+                return;
+            };
+
+        let editor = if let Some(editor) = pending_assist.editor.upgrade(cx) {
+            editor
+        } else {
+            return;
+        };
+
+        let selection = pending_assist.selection.clone();
+        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
+        let selected_text = snapshot
+            .text_for_range(selection.start..selection.end)
+            .collect::<Rope>();
+
+        let mut normalized_selected_text = selected_text.clone();
+        let mut base_indentation: Option<language::IndentSize> = None;
+        let selection_start = selection.start.to_point(&snapshot);
+        let selection_end = selection.end.to_point(&snapshot);
+        if selection_start.row < selection_end.row {
+            for row in selection_start.row..=selection_end.row {
+                if snapshot.is_line_blank(row) {
+                    continue;
+                }
+
+                let line_indentation = snapshot.indent_size_for_line(row);
+                if let Some(base_indentation) = base_indentation.as_mut() {
+                    if line_indentation.len < base_indentation.len {
+                        *base_indentation = line_indentation;
+                    }
+                } else {
+                    base_indentation = Some(line_indentation);
+                }
+            }
+        }
+
+        if let Some(base_indentation) = base_indentation {
+            for row in selection_start.row..=selection_end.row {
+                let selection_row = row - selection_start.row;
+                let line_start =
+                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
+                let indentation_len = if row == selection_start.row {
+                    base_indentation.len.saturating_sub(selection_start.column)
+                } else {
+                    let line_len = normalized_selected_text.line_len(selection_row);
+                    cmp::min(line_len, base_indentation.len)
+                };
+                let indentation_end = cmp::min(
+                    line_start + indentation_len as usize,
+                    normalized_selected_text.len(),
+                );
+                normalized_selected_text.replace(line_start..indentation_end, "");
+            }
+        }
+
+        let language_name = snapshot
+            .language_at(selection.start)
+            .map(|language| language.name());
+        let language_name = language_name.as_deref().unwrap_or("");
+
+        let mut prompt = String::new();
+        writeln!(prompt, "Given the following {language_name} snippet:").unwrap();
+        writeln!(prompt, "{normalized_selected_text}").unwrap();
+        writeln!(prompt, "{user_prompt}.").unwrap();
+        writeln!(prompt, "Never make remarks, reply only with the new code.").unwrap();
+        let request = OpenAIRequest {
+            model: "gpt-4".into(),
+            messages: vec![RequestMessage {
+                role: Role::User,
+                content: prompt,
+            }],
+            stream: true,
+        };
+        let response = stream_completion(api_key, cx.background().clone(), request);
+        let editor = editor.downgrade();
+
+        pending_assist.code_generation = cx.spawn(|this, mut cx| {
+            async move {
+                let _cleanup = util::defer({
+                    let mut cx = cx.clone();
+                    let this = this.clone();
+                    move || {
+                        let _ = this.update(&mut cx, |this, cx| {
+                            this.complete_inline_assist(inline_assist_id, false, cx)
+                        });
+                    }
+                });
+
+                let mut edit_start = selection.start.to_offset(&snapshot);
+
+                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+                let diff = cx.background().spawn(async move {
+                    let mut messages = response.await?;
+                    let mut diff = StreamingDiff::new(selected_text.to_string());
+
+                    let indentation_len;
+                    let indentation_text;
+                    if let Some(base_indentation) = base_indentation {
+                        indentation_len = base_indentation.len;
+                        indentation_text = match base_indentation.kind {
+                            language::IndentKind::Space => " ",
+                            language::IndentKind::Tab => "\t",
+                        };
+                    } else {
+                        indentation_len = 0;
+                        indentation_text = "";
+                    };
+
+                    let mut new_text = indentation_text
+                        .repeat(indentation_len.saturating_sub(selection_start.column) as usize);
+                    while let Some(message) = messages.next().await {
+                        let mut message = message?;
+                        if let Some(choice) = message.choices.pop() {
+                            if let Some(text) = choice.delta.content {
+                                let mut lines = text.split('\n');
+                                if let Some(first_line) = lines.next() {
+                                    new_text.push_str(&first_line);
+                                }
+
+                                for line in lines {
+                                    new_text.push('\n');
+                                    new_text.push_str(
+                                        &indentation_text.repeat(indentation_len as usize),
+                                    );
+                                    new_text.push_str(line);
+                                }
+                            }
+                        }
+
+                        let hunks = diff.push_new(&new_text);
+                        hunks_tx.send(hunks).await?;
+                        new_text.clear();
+                    }
+                    hunks_tx.send(diff.finish()).await?;
+
+                    anyhow::Ok(())
+                });
+
+                while let Some(hunks) = hunks_rx.next().await {
+                    let this = this
+                        .upgrade(&cx)
+                        .ok_or_else(|| anyhow!("assistant was dropped"))?;
+                    editor.update(&mut cx, |editor, cx| {
+                        let mut highlights = Vec::new();
+
+                        let transaction = editor.buffer().update(cx, |buffer, cx| {
+                            // Avoid grouping assistant edits with user edits.
+                            buffer.finalize_last_transaction(cx);
+
+                            buffer.start_transaction(cx);
+                            buffer.edit(
+                                hunks.into_iter().filter_map(|hunk| match hunk {
+                                    Hunk::Insert { text } => {
+                                        let edit_start = snapshot.anchor_after(edit_start);
+                                        Some((edit_start..edit_start, text))
+                                    }
+                                    Hunk::Remove { len } => {
+                                        let edit_end = edit_start + len;
+                                        let edit_range = snapshot.anchor_after(edit_start)
+                                            ..snapshot.anchor_before(edit_end);
+                                        edit_start = edit_end;
+                                        Some((edit_range, String::new()))
+                                    }
+                                    Hunk::Keep { len } => {
+                                        let edit_end = edit_start + len;
+                                        let edit_range = snapshot.anchor_after(edit_start)
+                                            ..snapshot.anchor_before(edit_end);
+                                        edit_start += len;
+                                        highlights.push(edit_range);
+                                        None
+                                    }
+                                }),
+                                None,
+                                cx,
+                            );
+
+                            buffer.end_transaction(cx)
+                        });
+
+                        if let Some(transaction) = transaction {
+                            this.update(cx, |this, cx| {
+                                if let Some(pending_assist) =
+                                    this.pending_inline_assists.get_mut(&inline_assist_id)
+                                {
+                                    if let Some(first_transaction) = pending_assist.transaction_id {
+                                        // Group all assistant edits into the first transaction.
+                                        editor.buffer().update(cx, |buffer, cx| {
+                                            buffer.merge_transactions(
+                                                transaction,
+                                                first_transaction,
+                                                cx,
+                                            )
+                                        });
+                                    } else {
+                                        pending_assist.transaction_id = Some(transaction);
+                                        editor.buffer().update(cx, |buffer, cx| {
+                                            buffer.finalize_last_transaction(cx)
+                                        });
+                                    }
+                                }
+                            });
+                        }
+
+                        editor.highlight_text::<Self>(
+                            highlights,
+                            gpui::fonts::HighlightStyle {
+                                fade_out: Some(0.6),
+                                ..Default::default()
+                            },
+                            cx,
+                        );
+                    })?;
+                }
+                diff.await?;
+
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
             ConversationEditor::new(
@@ -565,6 +1001,32 @@ impl AssistantPanel {
             .iter()
             .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }
+
+    pub fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
+        if self.api_key.borrow().is_none() && !self.has_read_credentials {
+            self.has_read_credentials = true;
+            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                Some(api_key)
+            } else if let Some((_, api_key)) = cx
+                .platform()
+                .read_credentials(OPENAI_API_URL)
+                .log_err()
+                .flatten()
+            {
+                String::from_utf8(api_key).log_err()
+            } else {
+                None
+            };
+            if let Some(api_key) = api_key {
+                *self.api_key.borrow_mut() = Some(api_key);
+            } else if self.api_key_editor.is_none() {
+                self.api_key_editor = Some(build_api_key_editor(cx));
+                cx.notify();
+            }
+        }
+
+        self.api_key.borrow().clone()
+    }
 }
 
 fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
@@ -748,27 +1210,7 @@ impl Panel for AssistantPanel {
 
     fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
         if active {
-            if self.api_key.borrow().is_none() && !self.has_read_credentials {
-                self.has_read_credentials = true;
-                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                    Some(api_key)
-                } else if let Some((_, api_key)) = cx
-                    .platform()
-                    .read_credentials(OPENAI_API_URL)
-                    .log_err()
-                    .flatten()
-                {
-                    String::from_utf8(api_key).log_err()
-                } else {
-                    None
-                };
-                if let Some(api_key) = api_key {
-                    *self.api_key.borrow_mut() = Some(api_key);
-                } else if self.api_key_editor.is_none() {
-                    self.api_key_editor = Some(build_api_key_editor(cx));
-                    cx.notify();
-                }
-            }
+            self.load_api_key(cx);
 
             if self.editors.is_empty() {
                 self.new_conversation(cx);
@@ -2139,6 +2581,84 @@ impl Message {
     }
 }
 
+enum InlineAssistantEvent {
+    Confirmed { prompt: String },
+    Canceled,
+    Dismissed,
+}
+
+#[derive(Copy, Clone)]
+enum InlineAssistKind {
+    Edit,
+    Insert,
+}
+
+struct InlineAssistant {
+    id: usize,
+    prompt_editor: ViewHandle<Editor>,
+    confirmed: bool,
+    assist_kind: InlineAssistKind,
+    has_focus: bool,
+}
+
+impl Entity for InlineAssistant {
+    type Event = InlineAssistantEvent;
+}
+
+impl View for InlineAssistant {
+    fn ui_name() -> &'static str {
+        "InlineAssistant"
+    }
+
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        let theme = theme::current(cx);
+        let prompt_editor = ChildView::new(&self.prompt_editor, cx).aligned().left();
+        match self.assist_kind {
+            InlineAssistKind::Edit => prompt_editor
+                .contained()
+                .with_style(theme.assistant.inline.container)
+                .into_any(),
+            InlineAssistKind::Insert => prompt_editor.into_any(),
+        }
+    }
+
+    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
+        cx.focus(&self.prompt_editor);
+        self.has_focus = true;
+    }
+
+    fn focus_out(&mut self, _: gpui::AnyViewHandle, _: &mut ViewContext<Self>) {
+        self.has_focus = false;
+    }
+}
+
+impl InlineAssistant {
+    fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
+        cx.emit(InlineAssistantEvent::Canceled);
+    }
+
+    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+        if self.confirmed {
+            cx.emit(InlineAssistantEvent::Dismissed);
+        } else {
+            let prompt = self.prompt_editor.read(cx).text(cx);
+            self.prompt_editor
+                .update(cx, |editor, _| editor.set_read_only(true));
+            cx.emit(InlineAssistantEvent::Confirmed { prompt });
+            self.confirmed = true;
+        }
+    }
+}
+
+struct PendingInlineAssist {
+    editor: WeakViewHandle<Editor>,
+    selection: Selection<Anchor>,
+    inline_assistant_block_id: Option<BlockId>,
+    code_generation: Task<Option<()>>,
+    transaction_id: Option<TransactionId>,
+    _subscriptions: Vec<Subscription>,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

crates/ai/src/refactoring_assistant.rs 🔗

@@ -1,252 +0,0 @@
-use collections::HashMap;
-use editor::{Editor, ToOffset, ToPoint};
-use futures::{channel::mpsc, SinkExt, StreamExt};
-use gpui::{AppContext, Task, ViewHandle};
-use language::{Point, Rope};
-use std::{cmp, env, fmt::Write};
-use util::TryFutureExt;
-
-use crate::{
-    stream_completion,
-    streaming_diff::{Hunk, StreamingDiff},
-    OpenAIRequest, RequestMessage, Role,
-};
-
-pub struct RefactoringAssistant {
-    pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
-}
-
-impl RefactoringAssistant {
-    fn new() -> Self {
-        Self {
-            pending_edits_by_editor: Default::default(),
-        }
-    }
-
-    pub fn update<F, T>(cx: &mut AppContext, f: F) -> T
-    where
-        F: FnOnce(&mut Self, &mut AppContext) -> T,
-    {
-        if !cx.has_global::<Self>() {
-            cx.set_global(Self::new());
-        }
-
-        cx.update_global(f)
-    }
-
-    pub fn refactor(
-        &mut self,
-        editor: &ViewHandle<Editor>,
-        user_prompt: &str,
-        cx: &mut AppContext,
-    ) {
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            api_key
-        } else {
-            // TODO: ensure the API key is present by going through the assistant panel's flow.
-            return;
-        };
-
-        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
-        let selection = editor.read(cx).selections.newest_anchor().clone();
-        let selected_text = snapshot
-            .text_for_range(selection.start..selection.end)
-            .collect::<Rope>();
-
-        let mut normalized_selected_text = selected_text.clone();
-        let mut base_indentation: Option<language::IndentSize> = None;
-        let selection_start = selection.start.to_point(&snapshot);
-        let selection_end = selection.end.to_point(&snapshot);
-        if selection_start.row < selection_end.row {
-            for row in selection_start.row..=selection_end.row {
-                if snapshot.is_line_blank(row) {
-                    continue;
-                }
-
-                let line_indentation = snapshot.indent_size_for_line(row);
-                if let Some(base_indentation) = base_indentation.as_mut() {
-                    if line_indentation.len < base_indentation.len {
-                        *base_indentation = line_indentation;
-                    }
-                } else {
-                    base_indentation = Some(line_indentation);
-                }
-            }
-        }
-
-        if let Some(base_indentation) = base_indentation {
-            for row in selection_start.row..=selection_end.row {
-                let selection_row = row - selection_start.row;
-                let line_start =
-                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
-                let indentation_len = if row == selection_start.row {
-                    base_indentation.len.saturating_sub(selection_start.column)
-                } else {
-                    let line_len = normalized_selected_text.line_len(selection_row);
-                    cmp::min(line_len, base_indentation.len)
-                };
-                let indentation_end = cmp::min(
-                    line_start + indentation_len as usize,
-                    normalized_selected_text.len(),
-                );
-                normalized_selected_text.replace(line_start..indentation_end, "");
-            }
-        }
-
-        let language_name = snapshot
-            .language_at(selection.start)
-            .map(|language| language.name());
-        let language_name = language_name.as_deref().unwrap_or("");
-
-        let mut prompt = String::new();
-        writeln!(prompt, "Given the following {language_name} snippet:").unwrap();
-        writeln!(prompt, "{normalized_selected_text}").unwrap();
-        writeln!(prompt, "{user_prompt}.").unwrap();
-        writeln!(prompt, "Never make remarks, reply only with the new code.").unwrap();
-        let request = OpenAIRequest {
-            model: "gpt-4".into(),
-            messages: vec![RequestMessage {
-                role: Role::User,
-                content: prompt,
-            }],
-            stream: true,
-        };
-        let response = stream_completion(api_key, cx.background().clone(), request);
-        let editor = editor.downgrade();
-        self.pending_edits_by_editor.insert(
-            editor.id(),
-            cx.spawn(|mut cx| {
-                async move {
-                    let _clear_highlights = util::defer({
-                        let mut cx = cx.clone();
-                        let editor = editor.clone();
-                        move || {
-                            let _ = editor.update(&mut cx, |editor, cx| {
-                                editor.clear_text_highlights::<Self>(cx);
-                            });
-                        }
-                    });
-
-                    let mut edit_start = selection.start.to_offset(&snapshot);
-
-                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
-                    let diff = cx.background().spawn(async move {
-                        let mut messages = response.await?.ready_chunks(4);
-                        let mut diff = StreamingDiff::new(selected_text.to_string());
-
-                        let indentation_len;
-                        let indentation_text;
-                        if let Some(base_indentation) = base_indentation {
-                            indentation_len = base_indentation.len;
-                            indentation_text = match base_indentation.kind {
-                                language::IndentKind::Space => " ",
-                                language::IndentKind::Tab => "\t",
-                            };
-                        } else {
-                            indentation_len = 0;
-                            indentation_text = "";
-                        };
-
-                        let mut new_text =
-                            indentation_text.repeat(
-                                indentation_len.saturating_sub(selection_start.column) as usize,
-                            );
-                        while let Some(messages) = messages.next().await {
-                            for message in messages {
-                                let mut message = message?;
-                                if let Some(choice) = message.choices.pop() {
-                                    if let Some(text) = choice.delta.content {
-                                        let mut lines = text.split('\n');
-                                        if let Some(first_line) = lines.next() {
-                                            new_text.push_str(&first_line);
-                                        }
-
-                                        for line in lines {
-                                            new_text.push('\n');
-                                            new_text.push_str(
-                                                &indentation_text.repeat(indentation_len as usize),
-                                            );
-                                            new_text.push_str(line);
-                                        }
-                                    }
-                                }
-                            }
-
-                            let hunks = diff.push_new(&new_text);
-                            hunks_tx.send(hunks).await?;
-                            new_text.clear();
-                        }
-                        hunks_tx.send(diff.finish()).await?;
-
-                        anyhow::Ok(())
-                    });
-
-                    let mut first_transaction = None;
-                    while let Some(hunks) = hunks_rx.next().await {
-                        editor.update(&mut cx, |editor, cx| {
-                            let mut highlights = Vec::new();
-
-                            editor.buffer().update(cx, |buffer, cx| {
-                                // Avoid grouping assistant edits with user edits.
-                                buffer.finalize_last_transaction(cx);
-
-                                buffer.start_transaction(cx);
-                                buffer.edit(
-                                    hunks.into_iter().filter_map(|hunk| match hunk {
-                                        Hunk::Insert { text } => {
-                                            let edit_start = snapshot.anchor_after(edit_start);
-                                            Some((edit_start..edit_start, text))
-                                        }
-                                        Hunk::Remove { len } => {
-                                            let edit_end = edit_start + len;
-                                            let edit_range = snapshot.anchor_after(edit_start)
-                                                ..snapshot.anchor_before(edit_end);
-                                            edit_start = edit_end;
-                                            Some((edit_range, String::new()))
-                                        }
-                                        Hunk::Keep { len } => {
-                                            let edit_end = edit_start + len;
-                                            let edit_range = snapshot.anchor_after(edit_start)
-                                                ..snapshot.anchor_before(edit_end);
-                                            edit_start += len;
-                                            highlights.push(edit_range);
-                                            None
-                                        }
-                                    }),
-                                    None,
-                                    cx,
-                                );
-                                if let Some(transaction) = buffer.end_transaction(cx) {
-                                    if let Some(first_transaction) = first_transaction {
-                                        // Group all assistant edits into the first transaction.
-                                        buffer.merge_transactions(
-                                            transaction,
-                                            first_transaction,
-                                            cx,
-                                        );
-                                    } else {
-                                        first_transaction = Some(transaction);
-                                        buffer.finalize_last_transaction(cx);
-                                    }
-                                }
-                            });
-
-                            editor.highlight_text::<Self>(
-                                highlights,
-                                gpui::fonts::HighlightStyle {
-                                    fade_out: Some(0.6),
-                                    ..Default::default()
-                                },
-                                cx,
-                            );
-                        })?;
-                    }
-                    diff.await?;
-
-                    anyhow::Ok(())
-                }
-                .log_err()
-            }),
-        );
-    }
-}

crates/ai/src/refactoring_modal.rs 🔗

@@ -1,137 +0,0 @@
-use crate::refactoring_assistant::RefactoringAssistant;
-use collections::HashSet;
-use editor::{
-    display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle},
-    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
-    Editor,
-};
-use gpui::{
-    actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, View,
-    ViewContext, ViewHandle, WeakViewHandle,
-};
-use std::sync::Arc;
-use workspace::Workspace;
-
-actions!(assistant, [Refactor]);
-
-pub fn init(cx: &mut AppContext) {
-    cx.add_action(RefactoringModal::deploy);
-    cx.add_action(RefactoringModal::confirm);
-    cx.add_action(RefactoringModal::cancel);
-}
-
-enum Event {
-    Dismissed,
-}
-
-struct RefactoringModal {
-    active_editor: WeakViewHandle<Editor>,
-    prompt_editor: ViewHandle<Editor>,
-    has_focus: bool,
-}
-
-impl Entity for RefactoringModal {
-    type Event = Event;
-}
-
-impl View for RefactoringModal {
-    fn ui_name() -> &'static str {
-        "RefactoringModal"
-    }
-
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
-        let theme = theme::current(cx);
-        ChildView::new(&self.prompt_editor, cx)
-            .aligned()
-            .left()
-            .contained()
-            .with_style(theme.assistant.modal.container)
-            .mouse::<Self>(0)
-            .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed))
-            .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed))
-            .into_any()
-    }
-
-    fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
-        self.has_focus = true;
-        cx.focus(&self.prompt_editor);
-    }
-
-    fn focus_out(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
-        if !self.prompt_editor.is_focused(cx) {
-            self.has_focus = false;
-            cx.emit(Event::Dismissed);
-        }
-    }
-}
-
-impl RefactoringModal {
-    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
-        if let Some(active_editor) = workspace
-            .active_item(cx)
-            .and_then(|item| item.act_as::<Editor>(cx))
-        {
-            active_editor.update(cx, |editor, cx| {
-                let position = editor.selections.newest_anchor().head();
-                let prompt_editor = cx.add_view(|cx| {
-                    Editor::single_line(
-                        Some(Arc::new(|theme| theme.assistant.modal.editor.clone())),
-                        cx,
-                    )
-                });
-                let active_editor = cx.weak_handle();
-                let refactoring = cx.add_view(|_| RefactoringModal {
-                    active_editor,
-                    prompt_editor,
-                    has_focus: false,
-                });
-                cx.focus(&refactoring);
-
-                let block_id = editor.insert_blocks(
-                    [BlockProperties {
-                        style: BlockStyle::Flex,
-                        position,
-                        height: 2,
-                        render: Arc::new({
-                            let refactoring = refactoring.clone();
-                            move |cx: &mut BlockContext| {
-                                ChildView::new(&refactoring, cx)
-                                    .contained()
-                                    .with_padding_left(cx.gutter_width)
-                                    .into_any()
-                            }
-                        }),
-                        disposition: BlockDisposition::Below,
-                    }],
-                    Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
-                    cx,
-                )[0];
-                cx.subscribe(&refactoring, move |_, refactoring, event, cx| {
-                    let Event::Dismissed = event;
-                    if let Some(active_editor) = refactoring.read(cx).active_editor.upgrade(cx) {
-                        cx.window_context().defer(move |cx| {
-                            active_editor.update(cx, |editor, cx| {
-                                editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
-                            })
-                        });
-                    }
-                })
-                .detach();
-            });
-        }
-    }
-
-    fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
-        cx.emit(Event::Dismissed);
-    }
-
-    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
-        if let Some(editor) = self.active_editor.upgrade(cx) {
-            let prompt = self.prompt_editor.read(cx).text(cx);
-            RefactoringAssistant::update(cx, |assistant, cx| {
-                assistant.refactor(&editor, &prompt, cx);
-            });
-            cx.emit(Event::Dismissed);
-        }
-    }
-}

crates/ai/src/streaming_diff.rs 🔗

@@ -83,8 +83,8 @@ pub struct StreamingDiff {
 impl StreamingDiff {
     const INSERTION_SCORE: f64 = -1.;
     const DELETION_SCORE: f64 = -5.;
-    const EQUALITY_BASE: f64 = 1.4;
-    const MAX_EQUALITY_EXPONENT: i32 = 64;
+    const EQUALITY_BASE: f64 = 2.;
+    const MAX_EQUALITY_EXPONENT: i32 = 20;
 
     pub fn new(old: String) -> Self {
         let old = old.chars().collect::<Vec<_>>();
@@ -117,12 +117,8 @@ impl StreamingDiff {
                     equal_run += 1;
                     self.equal_runs.insert((i, j), equal_run);
 
-                    if self.old[i - 1] == ' ' {
-                        self.scores.get(i - 1, j - 1)
-                    } else {
-                        let exponent = cmp::min(equal_run as i32, Self::MAX_EQUALITY_EXPONENT);
-                        self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
-                    }
+                    let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
+                    self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
                 } else {
                     f64::NEG_INFINITY
                 };

crates/editor/src/editor.rs 🔗

@@ -8209,7 +8209,7 @@ impl View for Editor {
         "Editor"
     }
 
-    fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+    fn focus_in(&mut self, focused: AnyViewHandle, cx: &mut ViewContext<Self>) {
         if cx.is_self_focused() {
             let focused_event = EditorFocused(cx.handle());
             cx.emit(Event::Focused);
@@ -8217,7 +8217,7 @@ impl View for Editor {
         }
         if let Some(rename) = self.pending_rename.as_ref() {
             cx.focus(&rename.editor);
-        } else {
+        } else if cx.is_self_focused() || !focused.is::<Editor>() {
             if !self.focused {
                 self.blink_manager.update(cx, BlinkManager::enable);
             }

crates/editor/src/multi_buffer.rs 🔗

@@ -626,7 +626,7 @@ impl MultiBuffer {
                 buffer.merge_transactions(transaction, destination)
             });
         } else {
-            if let Some(transaction) = self.history.remove_transaction(transaction) {
+            if let Some(transaction) = self.history.forget(transaction) {
                 if let Some(destination) = self.history.transaction_mut(destination) {
                     for (buffer_id, buffer_transaction_id) in transaction.buffer_transactions {
                         if let Some(destination_buffer_transaction_id) =
@@ -822,6 +822,18 @@ impl MultiBuffer {
         None
     }
 
+    pub fn undo_and_forget(&mut self, transaction_id: TransactionId, cx: &mut ModelContext<Self>) {
+        if let Some(buffer) = self.as_singleton() {
+            buffer.update(cx, |buffer, cx| buffer.undo_and_forget(transaction_id, cx));
+        } else if let Some(transaction) = self.history.forget(transaction_id) {
+            for (buffer_id, transaction_id) in transaction.buffer_transactions {
+                if let Some(BufferState { buffer, .. }) = self.buffers.borrow().get(&buffer_id) {
+                    buffer.update(cx, |buffer, cx| buffer.undo_and_forget(transaction_id, cx));
+                }
+            }
+        }
+    }
+
     pub fn stream_excerpts_with_context_lines(
         &mut self,
         excerpts: Vec<(ModelHandle<Buffer>, Vec<Range<text::Anchor>>)>,
@@ -3369,7 +3381,7 @@ impl History {
         }
     }
 
-    fn remove_transaction(&mut self, transaction_id: TransactionId) -> Option<Transaction> {
+    fn forget(&mut self, transaction_id: TransactionId) -> Option<Transaction> {
         if let Some(ix) = self
             .undo_stack
             .iter()

crates/language/src/buffer.rs 🔗

@@ -1664,6 +1664,22 @@ impl Buffer {
         }
     }
 
+    pub fn undo_and_forget(
+        &mut self,
+        transaction_id: TransactionId,
+        cx: &mut ModelContext<Self>,
+    ) -> bool {
+        let was_dirty = self.is_dirty();
+        let old_version = self.version.clone();
+        if let Some(operation) = self.text.undo_and_forget(transaction_id) {
+            self.send_operation(Operation::Buffer(operation), cx);
+            self.did_edit(&old_version, was_dirty, cx);
+            true
+        } else {
+            false
+        }
+    }
+
     pub fn undo_to_transaction(
         &mut self,
         transaction_id: TransactionId,

crates/text/src/text.rs 🔗

@@ -22,6 +22,7 @@ use postage::{oneshot, prelude::*};
 
 pub use rope::*;
 pub use selection::*;
+use util::ResultExt;
 
 use std::{
     cmp::{self, Ordering, Reverse},
@@ -1206,6 +1207,14 @@ impl Buffer {
         }
     }
 
+    pub fn undo_and_forget(&mut self, transaction_id: TransactionId) -> Option<Operation> {
+        if let Some(transaction) = self.history.forget(transaction_id) {
+            self.undo_or_redo(transaction).log_err()
+        } else {
+            None
+        }
+    }
+
     #[allow(clippy::needless_collect)]
     pub fn undo_to_transaction(&mut self, transaction_id: TransactionId) -> Vec<Operation> {
         let transactions = self

crates/theme/src/theme.rs 🔗

@@ -1124,14 +1124,15 @@ pub struct AssistantStyle {
     pub api_key_editor: FieldEditor,
     pub api_key_prompt: ContainedText,
     pub saved_conversation: SavedConversation,
-    pub modal: ModalAssistantStyle,
+    pub inline: InlineAssistantStyle,
 }
 
 #[derive(Clone, Deserialize, Default, JsonSchema)]
-pub struct ModalAssistantStyle {
+pub struct InlineAssistantStyle {
     #[serde(flatten)]
     pub container: ContainerStyle,
     pub editor: FieldEditor,
+    pub pending_edit_background: Color,
 }
 
 #[derive(Clone, Deserialize, Default, JsonSchema)]

styles/src/style_tree/assistant.ts 🔗

@@ -59,7 +59,7 @@ export default function assistant(): any {
             background: background(theme.highest),
             padding: { left: 12 },
         },
-        modal: {
+        inline: {
             border: border(theme.lowest, "on", {
                 top: true,
                 bottom: true,
@@ -69,7 +69,8 @@ export default function assistant(): any {
                 text: text(theme.lowest, "mono", "on", { size: "sm" }),
                 placeholder_text: text(theme.lowest, "sans", "on", "disabled"),
                 selection: theme.players[0],
-            }
+            },
+            pending_edit_background: background(theme.highest, "positive"),
         },
         message_header: {
             margin: { bottom: 4, top: 4 },