Allow for multi-cursor `assist` and `cycle_role` actions

Antonio Scandurra , Nathan Sobo , and Kyle Caverly created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Kyle Caverly <kyle@zed.dev>

Change summary

crates/ai/src/assistant.rs | 354 +++++++++++++++++++++++++++------------
1 file changed, 243 insertions(+), 111 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -22,7 +22,7 @@ use gpui::{
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
 use isahc::{http::StatusCode, Request, RequestExt};
-use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Selection, ToOffset as _};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use serde::Deserialize;
 use settings::SettingsStore;
 use std::{
@@ -591,106 +591,129 @@ impl Assistant {
 
     fn assist(
         &mut self,
-        selection: Selection<usize>,
+        selected_messages: HashSet<MessageId>,
         cx: &mut ModelContext<Self>,
-    ) -> Option<(MessageAnchor, MessageAnchor)> {
-        let request = OpenAIRequest {
-            model: self.model.clone(),
-            messages: self
-                .messages(cx)
-                .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
-                .collect(),
-            stream: true,
-        };
+    ) -> Vec<MessageAnchor> {
+        let mut user_messages = Vec::new();
+        for selected_message_id in selected_messages {
+            let selected_message_role =
+                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
+                    metadata.role
+                } else {
+                    continue;
+                };
+            let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else {
+                continue;
+            };
+            user_messages.push(user_message);
+            if selected_message_role == Role::User {
+                let request = OpenAIRequest {
+                    model: self.model.clone(),
+                    messages: self
+                        .messages(cx)
+                        .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                        .chain(Some(RequestMessage {
+                            role: Role::System,
+                            content: format!(
+                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
+                                selected_message_id.0
+                            ),
+                        }))
+                        .collect(),
+                    stream: true,
+                };
+
+                let Some(api_key) = self.api_key.borrow().clone() else { continue };
+                let stream = stream_completion(api_key, cx.background().clone(), request);
+                let assistant_message = self
+                    .insert_message_after(selected_message_id, Role::Assistant, cx)
+                    .unwrap();
+
+                let task = cx.spawn_weak({
+                    |this, mut cx| async move {
+                        let assistant_message_id = assistant_message.id;
+                        let stream_completion = async {
+                            let mut messages = stream.await?;
+
+                            while let Some(message) = messages.next().await {
+                                let mut message = message?;
+                                if let Some(choice) = message.choices.pop() {
+                                    this.upgrade(&cx)
+                                        .ok_or_else(|| anyhow!("assistant was dropped"))?
+                                        .update(&mut cx, |this, cx| {
+                                            let text: Arc<str> = choice.delta.content?.into();
+                                            let message_ix = this.message_anchors.iter().position(
+                                                |message| message.id == assistant_message_id,
+                                            )?;
+                                            this.buffer.update(cx, |buffer, cx| {
+                                                let offset = if message_ix + 1
+                                                    == this.message_anchors.len()
+                                                {
+                                                    buffer.len()
+                                                } else {
+                                                    this.message_anchors[message_ix + 1]
+                                                        .start
+                                                        .to_offset(buffer)
+                                                        .saturating_sub(1)
+                                                };
+                                                buffer.edit([(offset..offset, text)], None, cx);
+                                            });
+                                            cx.emit(AssistantEvent::StreamedCompletion);
+
+                                            Some(())
+                                        });
+                                }
+                            }
 
-        let api_key = self.api_key.borrow().clone()?;
-        let stream = stream_completion(api_key, cx.background().clone(), request);
-        let assistant_message = self.insert_message_after(
-            self.message_for_offset(selection.head(), cx)?.id,
-            Role::Assistant,
-            cx,
-        )?;
-        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
-
-        let task = cx.spawn_weak({
-            |this, mut cx| async move {
-                let assistant_message_id = assistant_message.id;
-                let stream_completion = async {
-                    let mut messages = stream.await?;
-
-                    while let Some(message) = messages.next().await {
-                        let mut message = message?;
-                        if let Some(choice) = message.choices.pop() {
                             this.upgrade(&cx)
                                 .ok_or_else(|| anyhow!("assistant was dropped"))?
                                 .update(&mut cx, |this, cx| {
-                                    let text: Arc<str> = choice.delta.content?.into();
-                                    let message_ix = this
-                                        .message_anchors
-                                        .iter()
-                                        .position(|message| message.id == assistant_message_id)?;
-                                    this.buffer.update(cx, |buffer, cx| {
-                                        let offset = if message_ix + 1 == this.message_anchors.len()
-                                        {
-                                            buffer.len()
-                                        } else {
-                                            this.message_anchors[message_ix + 1]
-                                                .start
-                                                .to_offset(buffer)
-                                                .saturating_sub(1)
-                                        };
-                                        buffer.edit([(offset..offset, text)], None, cx);
+                                    this.pending_completions.retain(|completion| {
+                                        completion.id != this.completion_count
                                     });
-                                    cx.emit(AssistantEvent::StreamedCompletion);
-
-                                    Some(())
+                                    this.summarize(cx);
                                 });
+
+                            anyhow::Ok(())
+                        };
+
+                        let result = stream_completion.await;
+                        if let Some(this) = this.upgrade(&cx) {
+                            this.update(&mut cx, |this, cx| {
+                                if let Err(error) = result {
+                                    if let Some(metadata) =
+                                        this.messages_metadata.get_mut(&assistant_message.id)
+                                    {
+                                        metadata.error = Some(error.to_string().trim().into());
+                                        cx.notify();
+                                    }
+                                }
+                            });
                         }
                     }
+                });
 
-                    this.upgrade(&cx)
-                        .ok_or_else(|| anyhow!("assistant was dropped"))?
-                        .update(&mut cx, |this, cx| {
-                            this.pending_completions
-                                .retain(|completion| completion.id != this.completion_count);
-                            this.summarize(cx);
-                        });
-
-                    anyhow::Ok(())
-                };
-
-                let result = stream_completion.await;
-                if let Some(this) = this.upgrade(&cx) {
-                    this.update(&mut cx, |this, cx| {
-                        if let Err(error) = result {
-                            if let Some(metadata) =
-                                this.messages_metadata.get_mut(&assistant_message.id)
-                            {
-                                metadata.error = Some(error.to_string().trim().into());
-                                cx.notify();
-                            }
-                        }
-                    });
-                }
+                self.pending_completions.push(PendingCompletion {
+                    id: post_inc(&mut self.completion_count),
+                    _task: task,
+                });
             }
-        });
+        }
 
-        self.pending_completions.push(PendingCompletion {
-            id: post_inc(&mut self.completion_count),
-            _task: task,
-        });
-        Some((assistant_message, user_message))
+        user_messages
     }
 
     fn cancel_last_assist(&mut self) -> bool {
         self.pending_completions.pop().is_some()
     }
 
-    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
-        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
-            metadata.role.cycle();
-            cx.emit(AssistantEvent::MessagesEdited);
-            cx.notify();
+    fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
+        for id in ids {
+            if let Some(metadata) = self.messages_metadata.get_mut(&id) {
+                metadata.role.cycle();
+                cx.emit(AssistantEvent::MessagesEdited);
+                cx.notify();
+            }
         }
     }
 
@@ -884,14 +907,39 @@ impl Assistant {
         }
     }
 
-    fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
+    fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
+        self.messages_for_offsets([offset], cx).pop()
+    }
+
+    fn messages_for_offsets(
+        &self,
+        offsets: impl IntoIterator<Item = usize>,
+        cx: &AppContext,
+    ) -> Vec<Message> {
+        let mut result = Vec::new();
+
+        let buffer_len = self.buffer.read(cx).len();
         let mut messages = self.messages(cx).peekable();
-        while let Some(message) = messages.next() {
-            if message.range.contains(&offset) || messages.peek().is_none() {
-                return Some(message);
+        let mut offsets = offsets.into_iter().peekable();
+        while let Some(offset) = offsets.next() {
+            // Skip messages that start after the offset.
+            while messages.peek().map_or(false, |message| {
+                message.range.end < offset || (message.range.end == offset && offset < buffer_len)
+            }) {
+                messages.next();
             }
+            let Some(message) = messages.peek() else { continue };
+
+            // Skip offsets that are in the same message.
+            while offsets.peek().map_or(false, |offset| {
+                message.range.contains(offset) || message.range.end == buffer_len
+            }) {
+                offsets.next();
+            }
+
+            result.push(message.clone());
         }
-        None
+        result
     }
 
     fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
@@ -983,24 +1031,32 @@ impl AssistantEditor {
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
-        let selection = self.editor.read(cx).selections.newest(cx);
-        let user_message = self.assistant.update(cx, |assistant, cx| {
-            let (_, user_message) = assistant.assist(selection, cx)?;
-            Some(user_message)
+        let cursors = self.cursors(cx);
+
+        let user_messages = self.assistant.update(cx, |assistant, cx| {
+            let selected_messages = assistant
+                .messages_for_offsets(cursors, cx)
+                .into_iter()
+                .map(|message| message.id)
+                .collect();
+            assistant.assist(selected_messages, cx)
+        });
+        let new_selections = user_messages
+            .iter()
+            .map(|message| {
+                let cursor = message
+                    .start
+                    .to_offset(self.assistant.read(cx).buffer.read(cx));
+                cursor..cursor
+            })
+            .collect::<Vec<_>>();
+        self.editor.update(cx, |editor, cx| {
+            editor.change_selections(
+                Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
+                cx,
+                |selections| selections.select_ranges(new_selections),
+            );
         });
-
-        if let Some(user_message) = user_message {
-            let cursor = user_message
-                .start
-                .to_offset(&self.assistant.read(cx).buffer.read(cx));
-            self.editor.update(cx, |editor, cx| {
-                editor.change_selections(
-                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
-                    cx,
-                    |selections| selections.select_ranges([cursor..cursor]),
-                );
-            });
-        }
     }
 
     fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
@@ -1013,14 +1069,25 @@ impl AssistantEditor {
     }
 
     fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
-        let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
+        let cursors = self.cursors(cx);
         self.assistant.update(cx, |assistant, cx| {
-            if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
-                assistant.cycle_message_role(message.id, cx);
-            }
+            let messages = assistant
+                .messages_for_offsets(cursors, cx)
+                .into_iter()
+                .map(|message| message.id)
+                .collect();
+            assistant.cycle_message_roles(messages, cx)
         });
     }
 
+    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
+        let selections = self.editor.read(cx).selections.all::<usize>(cx);
+        selections
+            .into_iter()
+            .map(|selection| selection.head())
+            .collect()
+    }
+
     fn handle_assistant_event(
         &mut self,
         _: ModelHandle<Assistant>,
@@ -1149,7 +1216,10 @@ impl AssistantEditor {
                                 let assistant = assistant.clone();
                                 move |_, _, cx| {
                                     assistant.update(cx, |assistant, cx| {
-                                        assistant.cycle_message_role(message_id, cx)
+                                        assistant.cycle_message_roles(
+                                            HashSet::from_iter(Some(message_id)),
+                                            cx,
+                                        )
                                     })
                                 }
                             });
@@ -1444,9 +1514,11 @@ pub struct Message {
 
 impl Message {
     fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
+        let mut content = format!("[Message {}]\n", self.id.0).to_string();
+        content.extend(buffer.text_for_range(self.range.clone()));
         RequestMessage {
             role: self.role,
-            content: buffer.text_for_range(self.range.clone()).collect(),
+            content,
         }
     }
 }
@@ -1761,6 +1833,66 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    fn test_messages_for_offsets(cx: &mut AppContext) {
+        let registry = Arc::new(LanguageRegistry::test());
+        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
+        let buffer = assistant.read(cx).buffer.clone();
+
+        let message_1 = assistant.read(cx).message_anchors[0].clone();
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![(message_1.id, Role::User, 0..0)]
+        );
+
+        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
+        let message_2 = assistant
+            .update(cx, |assistant, cx| {
+                assistant.insert_message_after(message_1.id, Role::User, cx)
+            })
+            .unwrap();
+        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
+
+        let message_3 = assistant
+            .update(cx, |assistant, cx| {
+                assistant.insert_message_after(message_2.id, Role::User, cx)
+            })
+            .unwrap();
+        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
+
+        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_2.id, Role::User, 4..8),
+                (message_3.id, Role::User, 8..11)
+            ]
+        );
+
+        assert_eq!(
+            message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
+            [message_1.id, message_2.id, message_3.id]
+        );
+        assert_eq!(
+            message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
+            [message_1.id, message_3.id]
+        );
+
+        fn message_ids_for_offsets(
+            assistant: &ModelHandle<Assistant>,
+            offsets: &[usize],
+            cx: &AppContext,
+        ) -> Vec<MessageId> {
+            assistant
+                .read(cx)
+                .messages_for_offsets(offsets.iter().copied(), cx)
+                .into_iter()
+                .map(|message| message.id)
+                .collect()
+        }
+    }
+
     fn messages(
         assistant: &ModelHandle<Assistant>,
         cx: &AppContext,