Cancel assists on escape

Nathan Sobo created

Change summary

assets/keymaps/default.json |  3 +
crates/ai/src/assistant.rs  | 73 ++++++++++++++++++++++++++++----------
2 files changed, 55 insertions(+), 21 deletions(-)

Detailed changes

assets/keymaps/default.json 🔗

@@ -198,7 +198,8 @@
   {
     "context": "ContextEditor > Editor",
     "bindings": {
-      "cmd-enter": "assistant::Assist"
+      "cmd-enter": "assistant::Assist",
+      "escape": "assistant::CancelLastAssist"
     }
   },
   {

crates/ai/src/assistant.rs 🔗

@@ -1,23 +1,24 @@
 use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
-use editor::{Editor, ExcerptRange, MultiBuffer};
+use editor::{Editor, MultiBuffer};
 use futures::StreamExt;
 use gpui::{
-    actions, elements::*, Action, AppContext, Entity, ModelHandle, Subscription, View, ViewContext,
-    ViewHandle, WeakViewHandle, WindowContext,
+    actions, elements::*, Action, AppContext, Entity, ModelHandle, Subscription, Task, View,
+    ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
 use language::{language_settings::SoftWrap, Anchor, Buffer};
 use std::sync::Arc;
-use util::ResultExt;
+use util::{post_inc, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
     item::Item,
     pane, Pane, Workspace,
 };
 
-actions!(assistant, [NewContext, Assist]);
+actions!(assistant, [NewContext, Assist, CancelLastAssist]);
 
 pub fn init(cx: &mut AppContext) {
     cx.add_action(ContextEditor::assist);
+    cx.add_action(ContextEditor::cancel_last_assist);
 }
 
 pub enum AssistantPanelEvent {
@@ -203,6 +204,13 @@ impl Panel for AssistantPanel {
 struct ContextEditor {
     messages: Vec<Message>,
     editor: ViewHandle<Editor>,
+    completion_count: usize,
+    pending_completions: Vec<PendingCompletion>,
+}
+
+struct PendingCompletion {
+    id: usize,
+    task: Task<Option<()>>,
 }
 
 impl ContextEditor {
@@ -230,7 +238,12 @@ impl ContextEditor {
             editor
         });
 
-        Self { messages, editor }
+        Self {
+            messages,
+            editor,
+            completion_count: 0,
+            pending_completions: Vec::new(),
+        }
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
@@ -265,22 +278,42 @@ impl ContextEditor {
                     );
                 });
             });
-            cx.spawn(|_, mut cx| async move {
-                let mut messages = stream.await?;
-
-                while let Some(message) = messages.next().await {
-                    let mut message = message?;
-                    if let Some(choice) = message.choices.pop() {
-                        content.update(&mut cx, |content, cx| {
-                            let text: Arc<str> = choice.delta.content?.into();
-                            content.edit([(content.len()..content.len(), text)], None, cx);
-                            Some(())
-                        });
+            let task = cx.spawn(|this, mut cx| {
+                async move {
+                    let mut messages = stream.await?;
+
+                    while let Some(message) = messages.next().await {
+                        let mut message = message?;
+                        if let Some(choice) = message.choices.pop() {
+                            content.update(&mut cx, |content, cx| {
+                                let text: Arc<str> = choice.delta.content?.into();
+                                content.edit([(content.len()..content.len(), text)], None, cx);
+                                Some(())
+                            });
+                        }
                     }
+
+                    this.update(&mut cx, |this, _| {
+                        this.pending_completions
+                            .retain(|completion| completion.id != this.completion_count);
+                    })
+                    .ok();
+
+                    anyhow::Ok(())
                 }
-                anyhow::Ok(())
-            })
-            .detach_and_log_err(cx);
+                .log_err()
+            });
+
+            self.pending_completions.push(PendingCompletion {
+                id: post_inc(&mut self.completion_count),
+                task,
+            });
+        }
+    }
+
+    fn cancel_last_assist(&mut self, _: &CancelLastAssist, cx: &mut ViewContext<Self>) {
+        if self.pending_completions.pop().is_none() {
+            cx.propagate_action();
         }
     }
 }