Insert reply after assistant message when hitting `cmd-enter`

Created by Antonio Scandurra and Nathan Sobo

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/ai/src/assistant.rs | 104 ++++++++++++++++++++++++++++-----------
1 file changed, 73 insertions(+), 31 deletions(-)
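
In short: `push_message` becomes `insert_message_after`, which takes the `ExcerptId` of an existing message so the reply triggered by `cmd-enter` can be spliced in directly after that message rather than always appended (callers that do want to append pass `ExcerptId::max()`), and the token-counting and completion tasks now use `cx.spawn_weak` so they no longer keep the `Assistant` alive. A minimal standalone sketch of the insertion-index computation, using simplified stand-ins rather than the real zed types:

```rust
// Stand-in types, only to illustrate where insert_message_after places the
// new message in the `messages` vector.
#[derive(Clone, Copy, PartialEq, Eq)]
struct ExcerptId(u64);

struct Message {
    excerpt_id: ExcerptId,
}

// The new message goes directly after the message whose excerpt matches
// `after`; if nothing matches (e.g. a max-sentinel id), it is appended.
fn insertion_index(messages: &[Message], after: ExcerptId) -> usize {
    messages
        .iter()
        .position(|message| message.excerpt_id == after)
        .map_or(messages.len(), |ix| ix + 1)
}

fn main() {
    let messages = vec![
        Message { excerpt_id: ExcerptId(1) },
        Message { excerpt_id: ExcerptId(2) },
    ];
    assert_eq!(insertion_index(&messages, ExcerptId(1)), 1); // reply lands between 1 and 2
    assert_eq!(insertion_index(&messages, ExcerptId(9)), 2); // unknown id falls back to append
}
```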

Detailed changes

crates/ai/src/assistant.rs

@@ -459,7 +459,7 @@ impl Assistant {
             api_key,
             buffer,
         };
-        this.push_message(Role::User, cx);
+        this.insert_message_after(ExcerptId::max(), Role::User, cx);
         this.count_remaining_tokens(cx);
         this
     }
@@ -498,7 +498,7 @@ impl Assistant {
             })
             .collect::<Vec<_>>();
         let model = self.model.clone();
-        self.pending_token_count = cx.spawn(|this, mut cx| {
+        self.pending_token_count = cx.spawn_weak(|this, mut cx| {
             async move {
                 cx.background().timer(Duration::from_millis(200)).await;
                 let token_count = cx
@@ -506,11 +506,13 @@ impl Assistant {
                     .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
                     .await?;
 
-                this.update(&mut cx, |this, cx| {
-                    this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
-                    this.token_count = Some(token_count);
-                    cx.notify()
-                });
+                this.upgrade(&cx)
+                    .ok_or_else(|| anyhow!("assistant was dropped"))?
+                    .update(&mut cx, |this, cx| {
+                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+                        this.token_count = Some(token_count);
+                        cx.notify()
+                    });
                 anyhow::Ok(())
             }
             .log_err()
@@ -547,9 +549,10 @@ impl Assistant {
         let api_key = self.api_key.borrow().clone();
         if let Some(api_key) = api_key {
             let stream = stream_completion(api_key, cx.background().clone(), request);
-            let (excerpt_id, content) = self.push_message(Role::Assistant, cx);
-            self.push_message(Role::User, cx);
-            let task = cx.spawn(|this, mut cx| async move {
+            let (excerpt_id, content) =
+                self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+            self.insert_message_after(ExcerptId::max(), Role::User, cx);
+            let task = cx.spawn_weak(|this, mut cx| async move {
                 let stream_completion = async {
                     let mut messages = stream.await?;
 
@@ -564,22 +567,26 @@ impl Assistant {
                         }
                     }
 
-                    this.update(&mut cx, |this, cx| {
-                        this.pending_completions
-                            .retain(|completion| completion.id != this.completion_count);
-                        this.summarize(cx);
-                    });
+                    this.upgrade(&cx)
+                        .ok_or_else(|| anyhow!("assistant was dropped"))?
+                        .update(&mut cx, |this, cx| {
+                            this.pending_completions
+                                .retain(|completion| completion.id != this.completion_count);
+                            this.summarize(cx);
+                        });
 
                     anyhow::Ok(())
                 };
 
                 if let Err(error) = stream_completion.await {
-                    this.update(&mut cx, |this, cx| {
-                        if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
-                            metadata.error = Some(error.to_string().trim().into());
-                            cx.notify();
-                        }
-                    })
+                    if let Some(this) = this.upgrade(&cx) {
+                        this.update(&mut cx, |this, cx| {
+                            if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
+                                metadata.error = Some(error.to_string().trim().into());
+                                cx.notify();
+                            }
+                        });
+                    }
                 }
             });
 
@@ -632,8 +639,9 @@ impl Assistant {
         }
     }
 
-    fn push_message(
+    fn insert_message_after(
         &mut self,
+        excerpt_id: ExcerptId,
         role: Role,
         cx: &mut ModelContext<Self>,
     ) -> (ExcerptId, ModelHandle<Buffer>) {
@@ -654,9 +662,10 @@ impl Assistant {
             buffer.set_language_registry(self.languages.clone());
             buffer
         });
-        let excerpt_id = self.buffer.update(cx, |buffer, cx| {
+        let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
             buffer
-                .push_excerpts(
+                .insert_excerpts_after(
+                    excerpt_id,
                     content.clone(),
                     vec![ExcerptRange {
                         context: 0..0,
@@ -668,19 +677,27 @@ impl Assistant {
                 .unwrap()
         });
 
-        self.messages.push(Message {
-            excerpt_id,
-            content: content.clone(),
-        });
+        let ix = self
+            .messages
+            .iter()
+            .position(|message| message.excerpt_id == excerpt_id)
+            .map_or(self.messages.len(), |ix| ix + 1);
+        self.messages.insert(
+            ix,
+            Message {
+                excerpt_id: new_excerpt_id,
+                content: content.clone(),
+            },
+        );
         self.messages_metadata.insert(
-            excerpt_id,
+            new_excerpt_id,
             MessageMetadata {
                 role,
                 sent_at: Local::now(),
                 error: None,
             },
         );
-        (excerpt_id, content)
+        (new_excerpt_id, content)
     }
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@@ -882,7 +899,7 @@ impl AssistantEditor {
                     if metadata.role == Role::User {
                         assistant.assist(cx);
                     } else {
-                        assistant.push_message(Role::User, cx);
+                        assistant.insert_message_after(excerpt_id, Role::User, cx);
                     }
                 }
             }
@@ -1227,3 +1244,28 @@ async fn stream_completion(
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::AppContext;
+
+    #[gpui::test]
+    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+        let registry = Arc::new(LanguageRegistry::test());
+
+        cx.add_model(|cx| {
+            let mut assistant = Assistant::new(Default::default(), registry, cx);
+            let (excerpt_1, _) =
+                assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+            let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
+            let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
+            assistant.remove_empty_messages(
+                HashSet::from_iter([excerpt_2, excerpt_3]),
+                Default::default(),
+                cx,
+            );
+            assistant
+        });
+    }
+}
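
Aside on the `spawn_weak` change: both background tasks now hold only a weak handle to the `Assistant` and bail out with an error if the model has been dropped before the work completes, instead of keeping it alive. A rough illustration of that upgrade-or-bail shape using `std::sync::Weak` (an analogy only; gpui's weak model handles and `anyhow`-based errors are what the diff actually uses):

```rust
use std::sync::{Arc, Mutex, Weak};

struct Assistant {
    token_count: Option<usize>,
}

// Upgrade-or-bail: the task holds a Weak handle, so a dropped Assistant makes
// the update fail fast rather than being kept alive by the pending future.
fn apply_token_count(this: &Weak<Mutex<Assistant>>, count: usize) -> Result<(), &'static str> {
    let this = this.upgrade().ok_or("assistant was dropped")?;
    this.lock().unwrap().token_count = Some(count);
    Ok(())
}

fn main() {
    let assistant = Arc::new(Mutex::new(Assistant { token_count: None }));
    apply_token_count(&Arc::downgrade(&assistant), 42).unwrap();
    assert_eq!(assistant.lock().unwrap().token_count, Some(42));

    // Once every strong handle is gone, the update reports an error instead of running.
    let weak = Arc::downgrade(&assistant);
    drop(assistant);
    assert!(apply_token_count(&weak, 7).is_err());
}
```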