Get back to a compiling state with `Buffer` backing the assistant

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 328 ++++++++++++++++++++++++++-------------
1 file changed, 214 insertions(+), 114 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -11,7 +11,7 @@ use editor::{
         autoscroll::{Autoscroll, AutoscrollStrategy},
         ScrollAnchor,
     },
-    Anchor, DisplayPoint, Editor, ExcerptId,
+    Anchor, DisplayPoint, Editor, ToOffset as _,
 };
 use fs::Fs;
 use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
@@ -25,10 +25,13 @@ use gpui::{
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
 };
 use isahc::{http::StatusCode, Request, RequestExt};
-use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
+use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
 use serde::Deserialize;
 use settings::SettingsStore;
-use std::{borrow::Cow, cell::RefCell, cmp, fmt::Write, io, rc::Rc, sync::Arc, time::Duration};
+use std::{
+    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
+    time::Duration,
+};
 use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
@@ -507,16 +510,16 @@ impl Assistant {
 
     fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
         let messages = self
-            .messages
-            .iter()
+            .open_ai_request_messages(cx)
+            .into_iter()
             .filter_map(|message| {
                 Some(tiktoken_rs::ChatCompletionRequestMessage {
-                    role: match self.messages_metadata.get(&message.excerpt_id)?.role {
+                    role: match message.role {
                         Role::User => "user".into(),
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: message.content.read(cx).text(),
+                    content: message.content,
                     name: None,
                 })
             })
@@ -554,45 +557,47 @@ impl Assistant {
     }
 
     fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
-        let messages = self
-            .messages
-            .iter()
-            .filter_map(|message| {
-                Some(RequestMessage {
-                    role: self.messages_metadata.get(&message.excerpt_id)?.role,
-                    content: message.content.read(cx).text(),
-                })
-            })
-            .collect();
         let request = OpenAIRequest {
             model: self.model.clone(),
-            messages,
+            messages: self.open_ai_request_messages(cx),
             stream: true,
         };
 
         let api_key = self.api_key.borrow().clone()?;
         let stream = stream_completion(api_key, cx.background().clone(), request);
-        let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
-        let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
+        let assistant_message =
+            self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
+        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
         let task = cx.spawn_weak({
-            let assistant_message = assistant_message.clone();
             |this, mut cx| async move {
-                let assistant_message = assistant_message;
+                let assistant_message_id = assistant_message.id;
                 let stream_completion = async {
                     let mut messages = stream.await?;
 
                     while let Some(message) = messages.next().await {
                         let mut message = message?;
                         if let Some(choice) = message.choices.pop() {
-                            assistant_message.content.update(&mut cx, |content, cx| {
-                                let text: Arc<str> = choice.delta.content?.into();
-                                content.edit([(content.len()..content.len(), text)], None, cx);
-                                Some(())
-                            });
                             this.upgrade(&cx)
                                 .ok_or_else(|| anyhow!("assistant was dropped"))?
-                                .update(&mut cx, |_, cx| {
-                                    cx.emit(AssistantEvent::StreamedCompletion);
+                                .update(&mut cx, |this, cx| {
+                                    let text: Arc<str> = choice.delta.content?.into();
+                                    let message_ix = this
+                                        .messages
+                                        .iter()
+                                        .position(|message| message.id == assistant_message_id)?;
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        let offset = if message_ix + 1 == this.messages.len() {
+                                            buffer.len()
+                                        } else {
+                                            this.messages[message_ix + 1]
+                                                .start
+                                                .to_offset(buffer)
+                                                .saturating_sub(1)
+                                        };
+                                        buffer.edit([(offset..offset, text)], None, cx);
+                                    });
+
+                                    Some(())
                                 });
                         }
                     }
@@ -612,9 +617,8 @@ impl Assistant {
                 if let Some(this) = this.upgrade(&cx) {
                     this.update(&mut cx, |this, cx| {
                         if let Err(error) = result {
-                            if let Some(metadata) = this
-                                .messages_metadata
-                                .get_mut(&assistant_message.excerpt_id)
+                            if let Some(metadata) =
+                                this.messages_metadata.get_mut(&assistant_message.id)
                             {
                                 metadata.error = Some(error.to_string().trim().into());
                                 cx.notify();
@@ -642,33 +646,33 @@ impl Assistant {
         protected_offsets: HashSet<usize>,
         cx: &mut ModelContext<Self>,
     ) {
-        let mut offset = 0;
-        let mut excerpts_to_remove = Vec::new();
-        self.messages.retain(|message| {
-            let range = offset..offset + message.content.read(cx).len();
-            offset = range.end + 1;
-            if range.is_empty()
-                && !protected_offsets.contains(&range.start)
-                && messages.contains(&message.id)
-            {
-                excerpts_to_remove.push(message.excerpt_id);
-                self.messages_metadata.remove(&message.excerpt_id);
-                false
-            } else {
-                true
-            }
-        });
-
-        if !excerpts_to_remove.is_empty() {
-            self.buffer.update(cx, |buffer, cx| {
-                buffer.remove_excerpts(excerpts_to_remove, cx)
-            });
-            cx.notify();
-        }
+        // let mut offset = 0;
+        // let mut excerpts_to_remove = Vec::new();
+        // self.messages.retain(|message| {
+        //     let range = offset..offset + message.content.read(cx).len();
+        //     offset = range.end + 1;
+        //     if range.is_empty()
+        //         && !protected_offsets.contains(&range.start)
+        //         && messages.contains(&message.id)
+        //     {
+        //         excerpts_to_remove.push(message.excerpt_id);
+        //         self.messages_metadata.remove(&message.excerpt_id);
+        //         false
+        //     } else {
+        //         true
+        //     }
+        // });
+
+        // if !excerpts_to_remove.is_empty() {
+        //     self.buffer.update(cx, |buffer, cx| {
+        //         buffer.remove_excerpts(excerpts_to_remove, cx)
+        //     });
+        //     cx.notify();
+        // }
     }
 
-    fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
-        if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
+    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
+        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
             metadata.role.cycle();
             cx.notify();
         }
@@ -686,15 +690,18 @@ impl Assistant {
             .position(|message| message.id == message_id)
         {
             let start = self.buffer.update(cx, |buffer, cx| {
-                let len = buffer.len();
-                buffer.edit([(len..len, "\n")], None, cx);
-                buffer.anchor_before(len + 1)
+                let offset = self
+                    .messages
+                    .get(prev_message_ix + 1)
+                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
+                buffer.edit([(offset..offset, "\n")], None, cx);
+                buffer.anchor_before(offset + 1)
             });
             let message = Message {
                 id: MessageId(post_inc(&mut self.next_message_id.0)),
                 start,
             };
-            self.messages.insert(prev_message_ix, message.clone());
+            self.messages.insert(prev_message_ix + 1, message.clone());
             self.messages_metadata.insert(
                 message.id,
                 MessageMetadata {
@@ -713,24 +720,13 @@ impl Assistant {
         if self.messages.len() >= 2 && self.summary.is_none() {
             let api_key = self.api_key.borrow().clone();
             if let Some(api_key) = api_key {
-                // let messages = self
-                //     .messages
-                //     .iter()
-                //     .take(2)
-                //     .filter_map(|message| {
-                //         Some(RequestMessage {
-                //             role: self.messages_metadata.get(&message.id)?.role,
-                //             content: message.content.read(cx).text(),
-                //         })
-                //     })
-                //     .chain(Some(RequestMessage {
-                //         role: Role::User,
-                //         content:
-                //             "Summarize the conversation into a short title without punctuation"
-                //                 .into(),
-                //     }))
-                //     .collect();
-                let messages = todo!();
+                let mut messages = self.open_ai_request_messages(cx);
+                messages.truncate(2);
+                messages.push(RequestMessage {
+                    role: Role::User,
+                    content: "Summarize the conversation into a short title without punctuation"
+                        .into(),
+                });
                 let request = OpenAIRequest {
                     model: self.model.clone(),
                     messages,
@@ -760,6 +756,44 @@ impl Assistant {
             }
         }
     }
+
+    fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
+        let buffer = self.buffer.read(cx);
+        self.messages(cx)
+            .map(|(message, metadata, range)| RequestMessage {
+                role: metadata.role,
+                content: buffer.text_for_range(range).collect(),
+            })
+            .collect()
+    }
+
+    fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
+        Some(
+            self.messages(cx)
+                .find(|(_, _, range)| range.contains(&offset))
+                .map(|(message, _, _)| message)
+                .or(self.messages.last())?
+                .id,
+        )
+    }
+
+    fn messages<'a>(
+        &'a self,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
+        let buffer = self.buffer.read(cx);
+        let mut messages = self.messages.iter().peekable();
+        iter::from_fn(move || {
+            let message = messages.next()?;
+            let metadata = self.messages_metadata.get(&message.id)?;
+            let message_start = message.start.to_offset(buffer);
+            let message_end = messages
+                .peek()
+                .map_or(language::Anchor::MAX, |message| message.start)
+                .to_offset(buffer);
+            Some((message, metadata, message_start..message_end))
+        })
+    }
 }
 
 struct PendingCompletion {
@@ -812,16 +846,12 @@ impl AssistantEditor {
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
         let user_message = self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
-            let newest_selection = editor.selections.newest_anchor();
-            let message_id = if newest_selection.head() == Anchor::min() {
-                assistant.messages.first().map(|message| message.id)?
-            } else if newest_selection.head() == Anchor::max() {
-                assistant.messages.last().map(|message| message.id)?
-            } else {
-                todo!()
-                // newest_selection.head().excerpt_id()
-            };
-
+            let newest_selection = editor
+                .selections
+                .newest_anchor()
+                .head()
+                .to_offset(&editor.buffer().read(cx).snapshot(cx));
+            let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
             let metadata = assistant.messages_metadata.get(&message_id)?;
             let user_message = if metadata.role == Role::User {
                 let (_, user_message) = assistant.assist(cx)?;
@@ -834,16 +864,14 @@ impl AssistantEditor {
         });
 
         if let Some(user_message) = user_message {
+            let cursor = user_message
+                .start
+                .to_offset(&self.assistant.read(cx).buffer.read(cx));
             self.editor.update(cx, |editor, cx| {
-                let cursor = editor
-                    .buffer()
-                    .read(cx)
-                    .snapshot(cx)
-                    .anchor_in_excerpt(Default::default(), user_message.start);
                 editor.change_selections(
                     Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                     cx,
-                    |selections| selections.select_anchor_ranges([cursor..cursor]),
+                    |selections| selections.select_ranges([cursor..cursor]),
                 );
             });
             self.update_scroll_bottom(cx);
@@ -1011,7 +1039,7 @@ impl AssistantEditor {
             let mut copied_text = String::new();
             let mut spanned_messages = 0;
             for message in &assistant.messages {
-                // TODO
+                todo!();
                 // let message_range = offset..offset + message.content.read(cx).len() + 1;
                 let message_range = offset..offset + 1;
 
@@ -1260,28 +1288,100 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let registry = Arc::new(LanguageRegistry::test());
+        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
+        let buffer = assistant.read(cx).buffer.clone();
+
+        let message_1 = assistant.read(cx).messages[0].clone();
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![(message_1.id, Role::User, 0..0)]
+        );
 
-        cx.add_model(|cx| {
-            let mut assistant = Assistant::new(Default::default(), registry, cx);
-            let message_1 = assistant.messages[0].clone();
-            let message_2 = assistant
+        let message_2 = assistant.update(cx, |assistant, cx| {
+            assistant
                 .insert_message_after(message_1.id, Role::Assistant, cx)
-                .unwrap();
-            let message_3 = assistant
-                .insert_message_after(message_2.id, Role::User, cx)
-                .unwrap();
-            let message_4 = assistant
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..1),
+                (message_2.id, Role::Assistant, 1..1)
+            ]
+        );
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..3)
+            ]
+        );
+
+        let message_3 = assistant.update(cx, |assistant, cx| {
+            assistant
                 .insert_message_after(message_2.id, Role::User, cx)
-                .unwrap();
-            assistant.remove_empty_messages(
-                HashSet::from_iter([message_3.id, message_4.id]),
-                Default::default(),
-                cx,
-            );
-            assert_eq!(assistant.messages.len(), 2);
-            assert_eq!(assistant.messages[0].id, message_1.id);
-            assert_eq!(assistant.messages[1].id, message_2.id);
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_3.id, Role::User, 4..4)
+            ]
+        );
+
+        let message_4 = assistant.update(cx, |assistant, cx| {
             assistant
+                .insert_message_after(message_2.id, Role::User, cx)
+                .unwrap()
+        });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_4.id, Role::User, 4..5),
+                (message_3.id, Role::User, 5..5),
+            ]
+        );
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
         });
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..2),
+                (message_2.id, Role::Assistant, 2..4),
+                (message_4.id, Role::User, 4..6),
+                (message_3.id, Role::User, 6..7),
+            ]
+        );
+
+        // Deleting across message boundaries merges the messages.
+        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..6),
+                (message_3.id, Role::User, 6..7),
+            ]
+        );
+    }
+
+    fn messages(
+        assistant: &ModelHandle<Assistant>,
+        cx: &AppContext,
+    ) -> Vec<(MessageId, Role, Range<usize>)> {
+        assistant
+            .read(cx)
+            .messages(cx)
+            .map(|(message, metadata, range)| (message.id, metadata.role, range))
+            .collect()
     }
 }