Avoid including pending or errored messages on `assist`

Antonio Scandurra created

Change summary

Cargo.lock                 |   1 
crates/ai/Cargo.toml       |   1 
crates/ai/src/assistant.rs | 148 ++++++++++++++++++++++++---------------
3 files changed, 94 insertions(+), 56 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -114,6 +114,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "smol",
  "theme",
  "tiktoken-rs",
  "util",

crates/ai/Cargo.toml 🔗

@@ -28,6 +28,7 @@ isahc.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+smol.workspace = true
 tiktoken-rs = "0.4"
 
 [dev-dependencies]

crates/ai/src/assistant.rs 🔗

@@ -518,7 +518,7 @@ impl Assistant {
             MessageMetadata {
                 role: Role::User,
                 sent_at: Local::now(),
-                error: None,
+                status: MessageStatus::Done,
             },
         );
 
@@ -595,6 +595,7 @@ impl Assistant {
         cx: &mut ModelContext<Self>,
     ) -> Vec<MessageAnchor> {
         let mut user_messages = Vec::new();
+        let mut tasks = Vec::new();
         for selected_message_id in selected_messages {
             let selected_message_role =
                 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
@@ -604,15 +605,22 @@ impl Assistant {
                 };
 
             if selected_message_role == Role::Assistant {
-                let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else {
+                if let Some(user_message) = self.insert_message_after(
+                    selected_message_id,
+                    Role::User,
+                    MessageStatus::Done,
+                    cx,
+                ) {
+                    user_messages.push(user_message);
+                } else {
                     continue;
-                };
-                user_messages.push(user_message);
+                }
             } else {
                 let request = OpenAIRequest {
                     model: self.model.clone(),
                     messages: self
                         .messages(cx)
+                        .filter(|message| matches!(message.status, MessageStatus::Done))
                         .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
                         .chain(Some(RequestMessage {
                             role: Role::System,
@@ -628,10 +636,15 @@ impl Assistant {
                 let Some(api_key) = self.api_key.borrow().clone() else { continue };
                 let stream = stream_completion(api_key, cx.background().clone(), request);
                 let assistant_message = self
-                    .insert_message_after(selected_message_id, Role::Assistant, cx)
+                    .insert_message_after(
+                        selected_message_id,
+                        Role::Assistant,
+                        MessageStatus::Pending,
+                        cx,
+                    )
                     .unwrap();
 
-                let task = cx.spawn_weak({
+                tasks.push(cx.spawn_weak({
                     |this, mut cx| async move {
                         let assistant_message_id = assistant_message.id;
                         let stream_completion = async {
@@ -648,16 +661,15 @@ impl Assistant {
                                                 |message| message.id == assistant_message_id,
                                             )?;
                                             this.buffer.update(cx, |buffer, cx| {
-                                                let offset = if message_ix + 1
-                                                    == this.message_anchors.len()
-                                                {
-                                                    buffer.len()
-                                                } else {
-                                                    this.message_anchors[message_ix + 1]
-                                                        .start
-                                                        .to_offset(buffer)
-                                                        .saturating_sub(1)
-                                                };
+                                                let offset = this.message_anchors[message_ix + 1..]
+                                                    .iter()
+                                                    .find(|message| message.start.is_valid(buffer))
+                                                    .map_or(buffer.len(), |message| {
+                                                        message
+                                                            .start
+                                                            .to_offset(buffer)
+                                                            .saturating_sub(1)
+                                                    });
                                                 buffer.edit([(offset..offset, text)], None, cx);
                                             });
                                             cx.emit(AssistantEvent::StreamedCompletion);
@@ -665,6 +677,7 @@ impl Assistant {
                                             Some(())
                                         });
                                 }
+                                smol::future::yield_now().await;
                             }
 
                             this.upgrade(&cx)
@@ -682,26 +695,35 @@ impl Assistant {
                         let result = stream_completion.await;
                         if let Some(this) = this.upgrade(&cx) {
                             this.update(&mut cx, |this, cx| {
-                                if let Err(error) = result {
-                                    if let Some(metadata) =
-                                        this.messages_metadata.get_mut(&assistant_message.id)
-                                    {
-                                        metadata.error = Some(error.to_string().trim().into());
-                                        cx.notify();
+                                if let Some(metadata) =
+                                    this.messages_metadata.get_mut(&assistant_message.id)
+                                {
+                                    match result {
+                                        Ok(_) => {
+                                            metadata.status = MessageStatus::Done;
+                                        }
+                                        Err(error) => {
+                                            metadata.status = MessageStatus::Error(
+                                                error.to_string().trim().into(),
+                                            );
+                                        }
                                     }
+                                    cx.notify();
                                 }
                             });
                         }
                     }
-                });
-
-                self.pending_completions.push(PendingCompletion {
-                    id: post_inc(&mut self.completion_count),
-                    _task: task,
-                });
+                }));
             }
         }
 
+        if !tasks.is_empty() {
+            self.pending_completions.push(PendingCompletion {
+                id: post_inc(&mut self.completion_count),
+                _tasks: tasks,
+            });
+        }
+
         user_messages
     }
 
@@ -723,6 +745,7 @@ impl Assistant {
         &mut self,
         message_id: MessageId,
         role: Role,
+        status: MessageStatus,
         cx: &mut ModelContext<Self>,
     ) -> Option<MessageAnchor> {
         if let Some(prev_message_ix) = self
@@ -749,7 +772,7 @@ impl Assistant {
                 MessageMetadata {
                     role,
                     sent_at: Local::now(),
-                    error: None,
+                    status,
                 },
             );
             cx.emit(AssistantEvent::MessagesEdited);
@@ -808,7 +831,7 @@ impl Assistant {
                 MessageMetadata {
                     role,
                     sent_at: Local::now(),
-                    error: None,
+                    status: MessageStatus::Done,
                 },
             );
 
@@ -850,7 +873,7 @@ impl Assistant {
                     MessageMetadata {
                         role,
                         sent_at: Local::now(),
-                        error: None,
+                        status: MessageStatus::Done,
                     },
                 );
                 (Some(selection), Some(suffix))
@@ -970,7 +993,7 @@ impl Assistant {
                     anchor: message_anchor.start,
                     role: metadata.role,
                     sent_at: metadata.sent_at,
-                    error: metadata.error.clone(),
+                    status: metadata.status.clone(),
                 });
             }
             None
@@ -980,7 +1003,7 @@ impl Assistant {
 
 struct PendingCompletion {
     id: usize,
-    _task: Task<()>,
+    _tasks: Vec<Task<()>>,
 }
 
 enum AssistantEditorEvent {
@@ -1239,22 +1262,28 @@ impl AssistantEditor {
                                     .with_style(style.sent_at.container)
                                     .aligned(),
                                 )
-                                .with_children(message.error.as_ref().map(|error| {
-                                    Svg::new("icons/circle_x_mark_12.svg")
-                                        .with_color(style.error_icon.color)
-                                        .constrained()
-                                        .with_width(style.error_icon.width)
-                                        .contained()
-                                        .with_style(style.error_icon.container)
-                                        .with_tooltip::<ErrorTooltip>(
-                                            message_id.0,
-                                            error.to_string(),
-                                            None,
-                                            theme.tooltip.clone(),
-                                            cx,
+                                .with_children(
+                                    if let MessageStatus::Error(error) = &message.status {
+                                        Some(
+                                            Svg::new("icons/circle_x_mark_12.svg")
+                                                .with_color(style.error_icon.color)
+                                                .constrained()
+                                                .with_width(style.error_icon.width)
+                                                .contained()
+                                                .with_style(style.error_icon.container)
+                                                .with_tooltip::<ErrorTooltip>(
+                                                    message_id.0,
+                                                    error.to_string(),
+                                                    None,
+                                                    theme.tooltip.clone(),
+                                                    cx,
+                                                )
+                                                .aligned(),
                                         )
-                                        .aligned()
-                                }))
+                                    } else {
+                                        None
+                                    },
+                                )
                                 .aligned()
                                 .left()
                                 .contained()
@@ -1502,7 +1531,14 @@ struct MessageAnchor {
 struct MessageMetadata {
     role: Role,
     sent_at: DateTime<Local>,
-    error: Option<Arc<str>>,
+    status: MessageStatus,
+}
+
+#[derive(Clone, Debug)]
+enum MessageStatus {
+    Pending,
+    Done,
+    Error(Arc<str>),
 }
 
 #[derive(Clone, Debug)]
@@ -1513,7 +1549,7 @@ pub struct Message {
     anchor: language::Anchor,
     role: Role,
     sent_at: DateTime<Local>,
-    error: Option<Arc<str>>,
+    status: MessageStatus,
 }
 
 impl Message {
@@ -1632,7 +1668,7 @@ mod tests {
 
         let message_2 = assistant.update(cx, |assistant, cx| {
             assistant
-                .insert_message_after(message_1.id, Role::Assistant, cx)
+                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                 .unwrap()
         });
         assert_eq!(
@@ -1656,7 +1692,7 @@ mod tests {
 
         let message_3 = assistant.update(cx, |assistant, cx| {
             assistant
-                .insert_message_after(message_2.id, Role::User, cx)
+                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                 .unwrap()
         });
         assert_eq!(
@@ -1670,7 +1706,7 @@ mod tests {
 
         let message_4 = assistant.update(cx, |assistant, cx| {
             assistant
-                .insert_message_after(message_2.id, Role::User, cx)
+                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                 .unwrap()
         });
         assert_eq!(
@@ -1731,7 +1767,7 @@ mod tests {
         // Ensure we can still insert after a merged message.
         let message_5 = assistant.update(cx, |assistant, cx| {
             assistant
-                .insert_message_after(message_1.id, Role::System, cx)
+                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                 .unwrap()
         });
         assert_eq!(
@@ -1852,14 +1888,14 @@ mod tests {
         buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
         let message_2 = assistant
             .update(cx, |assistant, cx| {
-                assistant.insert_message_after(message_1.id, Role::User, cx)
+                assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
             })
             .unwrap();
         buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
 
         let message_3 = assistant
             .update(cx, |assistant, cx| {
-                assistant.insert_message_after(message_2.id, Role::User, cx)
+                assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
             })
             .unwrap();
         buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));