Remove the ability to reply to specific message in assistant

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 237 +++++++++++++++++----------------------
1 file changed, 105 insertions(+), 132 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -1767,15 +1767,20 @@ impl Conversation {
         cx: &mut ModelContext<Self>,
     ) -> Vec<MessageAnchor> {
         let mut user_messages = Vec::new();
-        let mut tasks = Vec::new();
 
-        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
-            message
-                .start
-                .is_valid(self.buffer.read(cx))
-                .then_some(message.id)
-        });
+        let last_message_id = if let Some(last_message_id) =
+            self.message_anchors.iter().rev().find_map(|message| {
+                message
+                    .start
+                    .is_valid(self.buffer.read(cx))
+                    .then_some(message.id)
+            }) {
+            last_message_id
+        } else {
+            return Default::default();
+        };
 
+        let mut should_assist = false;
         for selected_message_id in selected_messages {
             let selected_message_role =
                 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
@@ -1792,144 +1797,111 @@ impl Conversation {
                     cx,
                 ) {
                     user_messages.push(user_message);
-                } else {
-                    continue;
                 }
             } else {
-                let request = OpenAIRequest {
-                    model: self.model.full_name().to_string(),
-                    messages: self
-                        .messages(cx)
-                        .filter(|message| matches!(message.status, MessageStatus::Done))
-                        .flat_map(|message| {
-                            let mut system_message = None;
-                            if message.id == selected_message_id {
-                                system_message = Some(RequestMessage {
-                                    role: Role::System,
-                                    content: concat!(
-                                        "Treat the following messages as additional knowledge you have learned about, ",
-                                        "but act as if they were not part of this conversation. That is, treat them ",
-                                        "as if the user didn't see them and couldn't possibly inquire about them."
-                                    ).into()
-                                });
-                            }
+                should_assist = true;
+            }
+        }
 
-                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
-                        })
-                        .chain(Some(RequestMessage {
-                            role: Role::System,
-                            content: format!(
-                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
-                                selected_message_id.0
-                            ),
-                        }))
-                        .collect(),
-                    stream: true,
-                };
+        if should_assist {
+            let Some(api_key) = self.api_key.borrow().clone() else {
+                return Default::default();
+            };
 
-                let Some(api_key) = self.api_key.borrow().clone() else {
-                    continue;
-                };
-                let stream = stream_completion(api_key, cx.background().clone(), request);
-                let assistant_message = self
-                    .insert_message_after(
-                        selected_message_id,
-                        Role::Assistant,
-                        MessageStatus::Pending,
-                        cx,
-                    )
-                    .unwrap();
-
-                // Queue up the user's next reply
-                if Some(selected_message_id) == last_message_id {
-                    let user_message = self
-                        .insert_message_after(
-                            assistant_message.id,
-                            Role::User,
-                            MessageStatus::Done,
-                            cx,
-                        )
-                        .unwrap();
-                    user_messages.push(user_message);
-                }
+            let request = OpenAIRequest {
+                model: self.model.full_name().to_string(),
+                messages: self
+                    .messages(cx)
+                    .filter(|message| matches!(message.status, MessageStatus::Done))
+                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                    .collect(),
+                stream: true,
+            };
 
-                tasks.push(cx.spawn_weak({
-                    |this, mut cx| async move {
-                        let assistant_message_id = assistant_message.id;
-                        let stream_completion = async {
-                            let mut messages = stream.await?;
-
-                            while let Some(message) = messages.next().await {
-                                let mut message = message?;
-                                if let Some(choice) = message.choices.pop() {
-                                    this.upgrade(&cx)
-                                        .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                        .update(&mut cx, |this, cx| {
-                                            let text: Arc<str> = choice.delta.content?.into();
-                                            let message_ix = this.message_anchors.iter().position(
-                                                |message| message.id == assistant_message_id,
-                                            )?;
-                                            this.buffer.update(cx, |buffer, cx| {
-                                                let offset = this.message_anchors[message_ix + 1..]
-                                                    .iter()
-                                                    .find(|message| message.start.is_valid(buffer))
-                                                    .map_or(buffer.len(), |message| {
-                                                        message
-                                                            .start
-                                                            .to_offset(buffer)
-                                                            .saturating_sub(1)
-                                                    });
-                                                buffer.edit([(offset..offset, text)], None, cx);
-                                            });
-                                            cx.emit(ConversationEvent::StreamedCompletion);
-
-                                            Some(())
+            let stream = stream_completion(api_key, cx.background().clone(), request);
+            let assistant_message = self
+                .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
+                .unwrap();
+
+            // Queue up the user's next reply.
+            let user_message = self
+                .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
+                .unwrap();
+            user_messages.push(user_message);
+
+            let task = cx.spawn_weak({
+                |this, mut cx| async move {
+                    let assistant_message_id = assistant_message.id;
+                    let stream_completion = async {
+                        let mut messages = stream.await?;
+
+                        while let Some(message) = messages.next().await {
+                            let mut message = message?;
+                            if let Some(choice) = message.choices.pop() {
+                                this.upgrade(&cx)
+                                    .ok_or_else(|| anyhow!("conversation was dropped"))?
+                                    .update(&mut cx, |this, cx| {
+                                        let text: Arc<str> = choice.delta.content?.into();
+                                        let message_ix =
+                                            this.message_anchors.iter().position(|message| {
+                                                message.id == assistant_message_id
+                                            })?;
+                                        this.buffer.update(cx, |buffer, cx| {
+                                            let offset = this.message_anchors[message_ix + 1..]
+                                                .iter()
+                                                .find(|message| message.start.is_valid(buffer))
+                                                .map_or(buffer.len(), |message| {
+                                                    message
+                                                        .start
+                                                        .to_offset(buffer)
+                                                        .saturating_sub(1)
+                                                });
+                                            buffer.edit([(offset..offset, text)], None, cx);
                                         });
-                                }
-                                smol::future::yield_now().await;
-                            }
+                                        cx.emit(ConversationEvent::StreamedCompletion);
 
-                            this.upgrade(&cx)
-                                .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                .update(&mut cx, |this, cx| {
-                                    this.pending_completions.retain(|completion| {
-                                        completion.id != this.completion_count
+                                        Some(())
                                     });
-                                    this.summarize(cx);
-                                });
+                            }
+                            smol::future::yield_now().await;
+                        }
 
-                            anyhow::Ok(())
-                        };
+                        this.upgrade(&cx)
+                            .ok_or_else(|| anyhow!("conversation was dropped"))?
+                            .update(&mut cx, |this, cx| {
+                                this.pending_completions
+                                    .retain(|completion| completion.id != this.completion_count);
+                                this.summarize(cx);
+                            });
 
-                        let result = stream_completion.await;
-                        if let Some(this) = this.upgrade(&cx) {
-                            this.update(&mut cx, |this, cx| {
-                                if let Some(metadata) =
-                                    this.messages_metadata.get_mut(&assistant_message.id)
-                                {
-                                    match result {
-                                        Ok(_) => {
-                                            metadata.status = MessageStatus::Done;
-                                        }
-                                        Err(error) => {
-                                            metadata.status = MessageStatus::Error(
-                                                error.to_string().trim().into(),
-                                            );
-                                        }
+                        anyhow::Ok(())
+                    };
+
+                    let result = stream_completion.await;
+                    if let Some(this) = this.upgrade(&cx) {
+                        this.update(&mut cx, |this, cx| {
+                            if let Some(metadata) =
+                                this.messages_metadata.get_mut(&assistant_message.id)
+                            {
+                                match result {
+                                    Ok(_) => {
+                                        metadata.status = MessageStatus::Done;
+                                    }
+                                    Err(error) => {
+                                        metadata.status =
+                                            MessageStatus::Error(error.to_string().trim().into());
                                     }
-                                    cx.notify();
                                 }
-                            });
-                        }
+                                cx.notify();
+                            }
+                        });
                     }
-                }));
-            }
-        }
+                }
+            });
 
-        if !tasks.is_empty() {
             self.pending_completions.push(PendingCompletion {
                 id: post_inc(&mut self.completion_count),
-                _tasks: tasks,
+                _task: task,
             });
         }
 
@@ -2296,7 +2268,7 @@ impl Conversation {
 
 struct PendingCompletion {
     id: usize,
-    _tasks: Vec<Task<()>>,
+    _task: Task<()>,
 }
 
 enum ConversationEditorEvent {
@@ -2844,8 +2816,9 @@ pub struct Message {
 
 impl Message {
     fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
-        let mut content = format!("[Message {}]\n", self.id.0).to_string();
-        content.extend(buffer.text_for_range(self.offset_range.clone()));
+        let content = buffer
+            .text_for_range(self.offset_range.clone())
+            .collect::<String>();
         RequestMessage {
             role: self.role,
             content: content.trim_end().into(),