Move to an inline refactoring prompt

Antonio Scandurra created

Change summary

crates/ai/src/ai.rs                    |   7 
crates/ai/src/refactoring_assistant.rs | 166 +++++++--------------------
crates/ai/src/refactoring_modal.rs     | 134 ++++++++++++++++++++++
crates/ai/src/streaming_diff.rs        |   8 
styles/src/style_tree/assistant.ts     |   3 
5 files changed, 186 insertions(+), 132 deletions(-)

Detailed changes

crates/ai/src/ai.rs 🔗

@@ -1,7 +1,8 @@
 pub mod assistant;
 mod assistant_settings;
-mod diff;
-mod refactor;
+mod refactoring_assistant;
+mod refactoring_modal;
+mod streaming_diff;
 
 use anyhow::{anyhow, Result};
 pub use assistant::AssistantPanel;
@@ -195,7 +196,7 @@ struct OpenAIChoice {
 
 pub fn init(cx: &mut AppContext) {
     assistant::init(cx);
-    refactor::init(cx);
+    refactoring_modal::init(cx);
 }
 
 pub async fn stream_completion(

crates/ai/src/refactor.rs → crates/ai/src/refactoring_assistant.rs 🔗

@@ -1,25 +1,16 @@
-use crate::{diff::Diff, stream_completion, OpenAIRequest, RequestMessage, Role};
 use collections::HashMap;
 use editor::{Editor, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, StreamExt};
-use gpui::{
-    actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, Task, View,
-    ViewContext, ViewHandle, WeakViewHandle,
-};
+use gpui::{AppContext, Task, ViewHandle};
 use language::{Point, Rope};
-use menu::{Cancel, Confirm};
-use std::{cmp, env, sync::Arc};
+use std::{cmp, env, fmt::Write};
 use util::TryFutureExt;
-use workspace::{Modal, Workspace};
-
-actions!(assistant, [Refactor]);
 
-pub fn init(cx: &mut AppContext) {
-    cx.set_global(RefactoringAssistant::new());
-    cx.add_action(RefactoringModal::deploy);
-    cx.add_action(RefactoringModal::confirm);
-    cx.add_action(RefactoringModal::cancel);
-}
+use crate::{
+    stream_completion,
+    streaming_diff::{Hunk, StreamingDiff},
+    OpenAIRequest, RequestMessage, Role,
+};
 
 pub struct RefactoringAssistant {
     pending_edits_by_editor: HashMap<usize, Task<Option<()>>>,
@@ -32,7 +23,30 @@ impl RefactoringAssistant {
         }
     }
 
-    fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
+    pub fn update<F, T>(cx: &mut AppContext, f: F) -> T
+    where
+        F: FnOnce(&mut Self, &mut AppContext) -> T,
+    {
+        if !cx.has_global::<Self>() {
+            cx.set_global(Self::new());
+        }
+
+        cx.update_global(f)
+    }
+
+    pub fn refactor(
+        &mut self,
+        editor: &ViewHandle<Editor>,
+        user_prompt: &str,
+        cx: &mut AppContext,
+    ) {
+        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            api_key
+        } else {
+            // TODO: ensure the API key is present by going through the assistant panel's flow.
+            return;
+        };
+
         let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
         let selection = editor.read(cx).selections.newest_anchor().clone();
         let selected_text = snapshot
@@ -83,18 +97,20 @@ impl RefactoringAssistant {
             .language_at(selection.start)
             .map(|language| language.name());
         let language_name = language_name.as_deref().unwrap_or("");
+
+        let mut prompt = String::new();
+        writeln!(prompt, "Given the following {language_name} snippet:").unwrap();
+        writeln!(prompt, "{normalized_selected_text}").unwrap();
+        writeln!(prompt, "{user_prompt}.").unwrap();
+        writeln!(prompt, "Never make remarks, reply only with the new code.").unwrap();
         let request = OpenAIRequest {
             model: "gpt-4".into(),
-            messages: vec![
-                RequestMessage {
+            messages: vec![RequestMessage {
                 role: Role::User,
-                content: format!(
-                    "Given the following {language_name} snippet:\n{normalized_selected_text}\n{prompt}. Never make remarks and reply only with the new code."
-                ),
+                content: prompt,
             }],
             stream: true,
         };
-        let api_key = env::var("OPENAI_API_KEY").unwrap();
         let response = stream_completion(api_key, cx.background().clone(), request);
         let editor = editor.downgrade();
         self.pending_edits_by_editor.insert(
@@ -116,7 +132,7 @@ impl RefactoringAssistant {
                     let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                     let diff = cx.background().spawn(async move {
                         let mut messages = response.await?.ready_chunks(4);
-                        let mut diff = Diff::new(selected_text.to_string());
+                        let mut diff = StreamingDiff::new(selected_text.to_string());
 
                         let indentation_len;
                         let indentation_text;
@@ -177,18 +193,18 @@ impl RefactoringAssistant {
                                 buffer.start_transaction(cx);
                                 buffer.edit(
                                     hunks.into_iter().filter_map(|hunk| match hunk {
-                                        crate::diff::Hunk::Insert { text } => {
+                                        Hunk::Insert { text } => {
                                             let edit_start = snapshot.anchor_after(edit_start);
                                             Some((edit_start..edit_start, text))
                                         }
-                                        crate::diff::Hunk::Remove { len } => {
+                                        Hunk::Remove { len } => {
                                             let edit_end = edit_start + len;
                                             let edit_range = snapshot.anchor_after(edit_start)
                                                 ..snapshot.anchor_before(edit_end);
                                             edit_start = edit_end;
                                             Some((edit_range, String::new()))
                                         }
-                                        crate::diff::Hunk::Keep { len } => {
+                                        Hunk::Keep { len } => {
                                             let edit_end = edit_start + len;
                                             let edit_range = snapshot.anchor_after(edit_start)
                                                 ..snapshot.anchor_before(edit_end);
@@ -234,99 +250,3 @@ impl RefactoringAssistant {
         );
     }
 }
-
-enum Event {
-    Dismissed,
-}
-
-struct RefactoringModal {
-    active_editor: WeakViewHandle<Editor>,
-    prompt_editor: ViewHandle<Editor>,
-    has_focus: bool,
-}
-
-impl Entity for RefactoringModal {
-    type Event = Event;
-}
-
-impl View for RefactoringModal {
-    fn ui_name() -> &'static str {
-        "RefactoringModal"
-    }
-
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
-        let theme = theme::current(cx);
-
-        ChildView::new(&self.prompt_editor, cx)
-            .constrained()
-            .with_width(theme.assistant.modal.width)
-            .contained()
-            .with_style(theme.assistant.modal.container)
-            .mouse::<Self>(0)
-            .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed))
-            .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed))
-            .aligned()
-            .right()
-            .into_any()
-    }
-
-    fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
-        self.has_focus = true;
-        cx.focus(&self.prompt_editor);
-    }
-
-    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
-        self.has_focus = false;
-    }
-}
-
-impl Modal for RefactoringModal {
-    fn has_focus(&self) -> bool {
-        self.has_focus
-    }
-
-    fn dismiss_on_event(event: &Self::Event) -> bool {
-        matches!(event, Self::Event::Dismissed)
-    }
-}
-
-impl RefactoringModal {
-    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
-        if let Some(active_editor) = workspace
-            .active_item(cx)
-            .and_then(|item| Some(item.act_as::<Editor>(cx)?.downgrade()))
-        {
-            workspace.toggle_modal(cx, |_, cx| {
-                let prompt_editor = cx.add_view(|cx| {
-                    let mut editor = Editor::auto_height(
-                        theme::current(cx).assistant.modal.editor_max_lines,
-                        Some(Arc::new(|theme| theme.assistant.modal.editor.clone())),
-                        cx,
-                    );
-                    editor
-                        .set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
-                    editor
-                });
-                cx.add_view(|_| RefactoringModal {
-                    active_editor,
-                    prompt_editor,
-                    has_focus: false,
-                })
-            });
-        }
-    }
-
-    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
-        cx.emit(Event::Dismissed);
-    }
-
-    fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
-        if let Some(editor) = self.active_editor.upgrade(cx) {
-            let prompt = self.prompt_editor.read(cx).text(cx);
-            cx.update_global(|assistant: &mut RefactoringAssistant, cx| {
-                assistant.refactor(&editor, &prompt, cx);
-            });
-            cx.emit(Event::Dismissed);
-        }
-    }
-}

crates/ai/src/refactoring_modal.rs 🔗

@@ -0,0 +1,134 @@
+use crate::refactoring_assistant::RefactoringAssistant;
+use collections::HashSet;
+use editor::{
+    display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle},
+    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
+    Editor,
+};
+use gpui::{
+    actions, elements::*, platform::MouseButton, AnyViewHandle, AppContext, Entity, View,
+    ViewContext, ViewHandle, WeakViewHandle,
+};
+use std::sync::Arc;
+use workspace::Workspace;
+
+actions!(assistant, [Refactor]);
+
+pub fn init(cx: &mut AppContext) {
+    cx.add_action(RefactoringModal::deploy);
+    cx.add_action(RefactoringModal::confirm);
+    cx.add_action(RefactoringModal::cancel);
+}
+
+enum Event {
+    Dismissed,
+}
+
+struct RefactoringModal {
+    active_editor: WeakViewHandle<Editor>,
+    prompt_editor: ViewHandle<Editor>,
+    has_focus: bool,
+}
+
+impl Entity for RefactoringModal {
+    type Event = Event;
+}
+
+impl View for RefactoringModal {
+    fn ui_name() -> &'static str {
+        "RefactoringModal"
+    }
+
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        ChildView::new(&self.prompt_editor, cx)
+            .mouse::<Self>(0)
+            .on_click_out(MouseButton::Left, |_, _, cx| cx.emit(Event::Dismissed))
+            .on_click_out(MouseButton::Right, |_, _, cx| cx.emit(Event::Dismissed))
+            .into_any()
+    }
+
+    fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+        self.has_focus = true;
+        cx.focus(&self.prompt_editor);
+    }
+
+    fn focus_out(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+        if !self.prompt_editor.is_focused(cx) {
+            self.has_focus = false;
+            cx.emit(Event::Dismissed);
+        }
+    }
+}
+
+impl RefactoringModal {
+    fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext<Workspace>) {
+        if let Some(active_editor) = workspace
+            .active_item(cx)
+            .and_then(|item| item.act_as::<Editor>(cx))
+        {
+            active_editor.update(cx, |editor, cx| {
+                let position = editor.selections.newest_anchor().head();
+                let prompt_editor = cx.add_view(|cx| {
+                    Editor::single_line(
+                        Some(Arc::new(|theme| theme.assistant.modal.editor.clone())),
+                        cx,
+                    )
+                });
+                let active_editor = cx.weak_handle();
+                let refactoring = cx.add_view(|_| RefactoringModal {
+                    active_editor,
+                    prompt_editor,
+                    has_focus: false,
+                });
+                cx.focus(&refactoring);
+
+                let block_id = editor.insert_blocks(
+                    [BlockProperties {
+                        style: BlockStyle::Flex,
+                        position,
+                        height: 2,
+                        render: Arc::new({
+                            let refactoring = refactoring.clone();
+                            move |cx: &mut BlockContext| {
+                                ChildView::new(&refactoring, cx)
+                                    .contained()
+                                    .with_padding_left(cx.gutter_width)
+                                    .aligned()
+                                    .left()
+                                    .into_any()
+                            }
+                        }),
+                        disposition: BlockDisposition::Below,
+                    }],
+                    Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
+                    cx,
+                )[0];
+                cx.subscribe(&refactoring, move |_, refactoring, event, cx| {
+                    let Event::Dismissed = event;
+                    if let Some(active_editor) = refactoring.read(cx).active_editor.upgrade(cx) {
+                        cx.window_context().defer(move |cx| {
+                            active_editor.update(cx, |editor, cx| {
+                                editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
+                            })
+                        });
+                    }
+                })
+                .detach();
+            });
+        }
+    }
+
+    fn cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
+        cx.emit(Event::Dismissed);
+    }
+
+    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+        if let Some(editor) = self.active_editor.upgrade(cx) {
+            let prompt = self.prompt_editor.read(cx).text(cx);
+            RefactoringAssistant::update(cx, |assistant, cx| {
+                assistant.refactor(&editor, &prompt, cx);
+            });
+            cx.emit(Event::Dismissed);
+        }
+    }
+}

crates/ai/src/diff.rs → crates/ai/src/streaming_diff.rs 🔗

@@ -71,7 +71,7 @@ pub enum Hunk {
     Keep { len: usize },
 }
 
-pub struct Diff {
+pub struct StreamingDiff {
     old: Vec<char>,
     new: Vec<char>,
     scores: Matrix,
@@ -80,10 +80,10 @@ pub struct Diff {
     equal_runs: HashMap<(usize, usize), u32>,
 }
 
-impl Diff {
+impl StreamingDiff {
     const INSERTION_SCORE: f64 = -1.;
     const DELETION_SCORE: f64 = -5.;
-    const EQUALITY_BASE: f64 = 1.618;
+    const EQUALITY_BASE: f64 = 2.;
     const MAX_EQUALITY_EXPONENT: i32 = 32;
 
     pub fn new(old: String) -> Self {
@@ -250,7 +250,7 @@ mod tests {
             .collect::<String>();
         log::info!("old text: {:?}", old);
 
-        let mut diff = Diff::new(old.clone());
+        let mut diff = StreamingDiff::new(old.clone());
         let mut hunks = Vec::new();
         let mut new_len = 0;
         let mut new = String::new();

styles/src/style_tree/assistant.ts 🔗

@@ -69,8 +69,7 @@ export default function assistant(): any {
             width: 500,
             editor_max_lines: 6,
             editor: {
-                background: background(theme.lowest),
-                text: text(theme.lowest, "mono", "on"),
+                text: text(theme.lowest, "mono", "on", { size: "sm" }),
                 placeholder_text: text(theme.lowest, "sans", "on", "disabled"),
                 selection: theme.players[0],
             }