WIP

Antonio Scandurra created

Change summary

Cargo.lock                 |   2 
crates/ai/src/assistant.rs | 314 +++++++++++++--------------------------
2 files changed, 110 insertions(+), 206 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7559,7 +7559,7 @@ dependencies = [
 [[package]]
 name = "tree-sitter-yaml"
 version = "0.0.1"
-source = "git+https://github.com/zed-industries/tree-sitter-yaml?rev=5694b7f290cd9ef998829a0a6d8391a666370886#5694b7f290cd9ef998829a0a6d8391a666370886"
+source = "git+https://github.com/zed-industries/tree-sitter-yaml?rev=f545a41f57502e1b5ddf2a6668896c1b0620f930#f545a41f57502e1b5ddf2a6668896c1b0620f930"
 dependencies = [
  "cc",
  "tree-sitter",

crates/ai/src/assistant.rs 🔗

@@ -11,7 +11,7 @@ use editor::{
         autoscroll::{Autoscroll, AutoscrollStrategy},
         ScrollAnchor,
     },
-    Anchor, DisplayPoint, Editor, ExcerptId, ExcerptRange, MultiBuffer,
+    Anchor, DisplayPoint, Editor, ExcerptId,
 };
 use fs::Fs;
 use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
@@ -420,15 +420,16 @@ impl Panel for AssistantPanel {
 }
 
 enum AssistantEvent {
-    MessagesEdited { ids: Vec<ExcerptId> },
+    MessagesEdited,
     SummaryChanged,
     StreamedCompletion,
 }
 
 struct Assistant {
-    buffer: ModelHandle<MultiBuffer>,
+    buffer: ModelHandle<Buffer>,
     messages: Vec<Message>,
-    messages_metadata: HashMap<ExcerptId, MessageMetadata>,
+    messages_metadata: HashMap<MessageId, MessageMetadata>,
+    next_message_id: MessageId,
     summary: Option<String>,
     pending_summary: Task<Option<()>>,
     completion_count: usize,
@@ -453,10 +454,11 @@ impl Assistant {
         cx: &mut ModelContext<Self>,
     ) -> Self {
         let model = "gpt-3.5-turbo";
-        let buffer = cx.add_model(|_| MultiBuffer::new(0));
+        let buffer = cx.add_model(|cx| Buffer::new(0, "", cx));
         let mut this = Self {
             messages: Default::default(),
             messages_metadata: Default::default(),
+            next_message_id: Default::default(),
             summary: None,
             pending_summary: Task::ready(None),
             completion_count: Default::default(),
@@ -470,23 +472,34 @@ impl Assistant {
             api_key,
             buffer,
         };
-        this.insert_message_after(ExcerptId::max(), Role::User, cx);
+        let message = Message {
+            id: MessageId(post_inc(&mut this.next_message_id.0)),
+            start: language::Anchor::MIN,
+        };
+        this.messages.push(message.clone());
+        this.messages_metadata.insert(
+            message.id,
+            MessageMetadata {
+                role: Role::User,
+                sent_at: Local::now(),
+                error: None,
+            },
+        );
+
         this.count_remaining_tokens(cx);
         this
     }
 
     fn handle_buffer_event(
         &mut self,
-        _: ModelHandle<MultiBuffer>,
-        event: &editor::multi_buffer::Event,
+        _: ModelHandle<Buffer>,
+        event: &language::Event,
         cx: &mut ModelContext<Self>,
     ) {
         match event {
-            editor::multi_buffer::Event::ExcerptsAdded { .. }
-            | editor::multi_buffer::Event::ExcerptsRemoved { .. }
-            | editor::multi_buffer::Event::Edited => self.count_remaining_tokens(cx),
-            editor::multi_buffer::Event::ExcerptsEdited { ids } => {
-                cx.emit(AssistantEvent::MessagesEdited { ids: ids.clone() });
+            language::Event::Edited => {
+                self.count_remaining_tokens(cx);
+                cx.emit(AssistantEvent::MessagesEdited);
             }
             _ => {}
         }
@@ -625,7 +638,7 @@ impl Assistant {
 
     fn remove_empty_messages<'a>(
         &mut self,
-        excerpts: HashSet<ExcerptId>,
+        messages: HashSet<MessageId>,
         protected_offsets: HashSet<usize>,
         cx: &mut ModelContext<Self>,
     ) {
@@ -636,7 +649,7 @@ impl Assistant {
             offset = range.end + 1;
             if range.is_empty()
                 && !protected_offsets.contains(&range.start)
-                && excerpts.contains(&message.excerpt_id)
+                && messages.contains(&message.id)
             {
                 excerpts_to_remove.push(message.excerpt_id);
                 self.messages_metadata.remove(&message.excerpt_id);
@@ -663,84 +676,61 @@ impl Assistant {
 
     fn insert_message_after(
         &mut self,
-        excerpt_id: ExcerptId,
+        message_id: MessageId,
         role: Role,
         cx: &mut ModelContext<Self>,
-    ) -> Message {
-        let content = cx.add_model(|cx| {
-            let mut buffer = Buffer::new(0, "", cx);
-            let markdown = self.languages.language_for_name("Markdown");
-            cx.spawn_weak(|buffer, mut cx| async move {
-                let markdown = markdown.await?;
-                let buffer = buffer
-                    .upgrade(&cx)
-                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
-                buffer.update(&mut cx, |buffer, cx| {
-                    buffer.set_language(Some(markdown), cx)
-                });
-                anyhow::Ok(())
-            })
-            .detach_and_log_err(cx);
-            buffer.set_language_registry(self.languages.clone());
-            buffer
-        });
-        let new_excerpt_id = self.buffer.update(cx, |buffer, cx| {
-            buffer
-                .insert_excerpts_after(
-                    excerpt_id,
-                    content.clone(),
-                    vec![ExcerptRange {
-                        context: 0..0,
-                        primary: None,
-                    }],
-                    cx,
-                )
-                .pop()
-                .unwrap()
-        });
-
-        let ix = self
+    ) -> Option<Message> {
+        if let Some(prev_message_ix) = self
             .messages
             .iter()
-            .position(|message| message.excerpt_id == excerpt_id)
-            .map_or(self.messages.len(), |ix| ix + 1);
-        let message = Message {
-            excerpt_id: new_excerpt_id,
-            content: content.clone(),
-        };
-        self.messages.insert(ix, message.clone());
-        self.messages_metadata.insert(
-            new_excerpt_id,
-            MessageMetadata {
-                role,
-                sent_at: Local::now(),
-                error: None,
-            },
-        );
-        message
+            .position(|message| message.id == message_id)
+        {
+            let start = self.buffer.update(cx, |buffer, cx| {
+                let len = buffer.len();
+                buffer.edit([(len..len, "\n")], None, cx);
+                buffer.anchor_before(len + 1)
+            });
+            let message = Message {
+                id: MessageId(post_inc(&mut self.next_message_id.0)),
+                start,
+            };
+            self.messages.insert(prev_message_ix, message.clone());
+            self.messages_metadata.insert(
+                message.id,
+                MessageMetadata {
+                    role,
+                    sent_at: Local::now(),
+                    error: None,
+                },
+            );
+            Some(message)
+        } else {
+            None
+        }
     }
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
         if self.messages.len() >= 2 && self.summary.is_none() {
             let api_key = self.api_key.borrow().clone();
             if let Some(api_key) = api_key {
-                let messages = self
-                    .messages
-                    .iter()
-                    .take(2)
-                    .filter_map(|message| {
-                        Some(RequestMessage {
-                            role: self.messages_metadata.get(&message.excerpt_id)?.role,
-                            content: message.content.read(cx).text(),
-                        })
-                    })
-                    .chain(Some(RequestMessage {
-                        role: Role::User,
-                        content:
-                            "Summarize the conversation into a short title without punctuation"
-                                .into(),
-                    }))
-                    .collect();
+                // let messages = self
+                //     .messages
+                //     .iter()
+                //     .take(2)
+                //     .filter_map(|message| {
+                //         Some(RequestMessage {
+                //             role: self.messages_metadata.get(&message.id)?.role,
+                //             content: message.content.read(cx).text(),
+                //         })
+                //     })
+                //     .chain(Some(RequestMessage {
+                //         role: Role::User,
+                //         content:
+                //             "Summarize the conversation into a short title without punctuation"
+                //                 .into(),
+                //     }))
+                //     .collect();
+                let messages = todo!();
                 let request = OpenAIRequest {
                     model: self.model.clone(),
                     messages,
@@ -796,98 +786,9 @@ impl AssistantEditor {
     ) -> Self {
         let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
         let editor = cx.add_view(|cx| {
-            let mut editor = Editor::for_multibuffer(assistant.read(cx).buffer.clone(), None, cx);
+            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
             editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
             editor.set_show_gutter(false, cx);
-            editor.set_render_excerpt_header(
-                {
-                    let assistant = assistant.clone();
-                    move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
-                        enum Sender {}
-                        enum ErrorTooltip {}
-
-                        let theme = theme::current(cx);
-                        let style = &theme.assistant;
-                        let excerpt_id = params.id;
-                        if let Some(metadata) = assistant
-                            .read(cx)
-                            .messages_metadata
-                            .get(&excerpt_id)
-                            .cloned()
-                        {
-                            let sender = MouseEventHandler::<Sender, _>::new(
-                                params.id.into(),
-                                cx,
-                                |state, _| match metadata.role {
-                                    Role::User => {
-                                        let style = style.user_sender.style_for(state, false);
-                                        Label::new("You", style.text.clone())
-                                            .contained()
-                                            .with_style(style.container)
-                                    }
-                                    Role::Assistant => {
-                                        let style = style.assistant_sender.style_for(state, false);
-                                        Label::new("Assistant", style.text.clone())
-                                            .contained()
-                                            .with_style(style.container)
-                                    }
-                                    Role::System => {
-                                        let style = style.system_sender.style_for(state, false);
-                                        Label::new("System", style.text.clone())
-                                            .contained()
-                                            .with_style(style.container)
-                                    }
-                                },
-                            )
-                            .with_cursor_style(CursorStyle::PointingHand)
-                            .on_down(MouseButton::Left, {
-                                let assistant = assistant.clone();
-                                move |_, _, cx| {
-                                    assistant.update(cx, |assistant, cx| {
-                                        assistant.cycle_message_role(excerpt_id, cx)
-                                    })
-                                }
-                            });
-
-                            Flex::row()
-                                .with_child(sender.aligned())
-                                .with_child(
-                                    Label::new(
-                                        metadata.sent_at.format("%I:%M%P").to_string(),
-                                        style.sent_at.text.clone(),
-                                    )
-                                    .contained()
-                                    .with_style(style.sent_at.container)
-                                    .aligned(),
-                                )
-                                .with_children(metadata.error.map(|error| {
-                                    Svg::new("icons/circle_x_mark_12.svg")
-                                        .with_color(style.error_icon.color)
-                                        .constrained()
-                                        .with_width(style.error_icon.width)
-                                        .contained()
-                                        .with_style(style.error_icon.container)
-                                        .with_tooltip::<ErrorTooltip>(
-                                            params.id.into(),
-                                            error,
-                                            None,
-                                            theme.tooltip.clone(),
-                                            cx,
-                                        )
-                                        .aligned()
-                                }))
-                                .aligned()
-                                .left()
-                                .contained()
-                                .with_style(style.header)
-                                .into_any()
-                        } else {
-                            Empty::new().into_any()
-                        }
-                    }
-                },
-                cx,
-            );
             editor
         });
 
@@ -912,26 +813,21 @@ impl AssistantEditor {
         let user_message = self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
             let newest_selection = editor.selections.newest_anchor();
-            let excerpt_id = if newest_selection.head() == Anchor::min() {
-                assistant
-                    .messages
-                    .first()
-                    .map(|message| message.excerpt_id)?
+            let message_id = if newest_selection.head() == Anchor::min() {
+                assistant.messages.first().map(|message| message.id)?
             } else if newest_selection.head() == Anchor::max() {
-                assistant
-                    .messages
-                    .last()
-                    .map(|message| message.excerpt_id)?
+                assistant.messages.last().map(|message| message.id)?
             } else {
-                newest_selection.head().excerpt_id()
+                todo!()
+                // newest_selection.head().excerpt_id()
             };
 
-            let metadata = assistant.messages_metadata.get(&excerpt_id)?;
+            let metadata = assistant.messages_metadata.get(&message_id)?;
             let user_message = if metadata.role == Role::User {
                 let (_, user_message) = assistant.assist(cx)?;
                 user_message
             } else {
-                let user_message = assistant.insert_message_after(excerpt_id, Role::User, cx);
+                let user_message = assistant.insert_message_after(message_id, Role::User, cx)?;
                 user_message
             };
             Some(user_message)
@@ -943,7 +839,7 @@ impl AssistantEditor {
                     .buffer()
                     .read(cx)
                     .snapshot(cx)
-                    .anchor_in_excerpt(user_message.excerpt_id, language::Anchor::MIN);
+                    .anchor_in_excerpt(Default::default(), user_message.start);
                 editor.change_selections(
                     Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                     cx,
@@ -970,16 +866,16 @@ impl AssistantEditor {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            AssistantEvent::MessagesEdited { ids } => {
+            AssistantEvent::MessagesEdited => {
                 let selections = self.editor.read(cx).selections.all::<usize>(cx);
                 let selection_heads = selections
                     .iter()
                     .map(|selection| selection.head())
                     .collect::<HashSet<usize>>();
-                let ids = ids.iter().copied().collect::<HashSet<_>>();
-                self.assistant.update(cx, |assistant, cx| {
-                    assistant.remove_empty_messages(ids, selection_heads, cx)
-                });
+                // let ids = ids.iter().copied().collect::<HashSet<_>>();
+                // self.assistant.update(cx, |assistant, cx| {
+                //     assistant.remove_empty_messages(ids, selection_heads, cx)
+                // });
             }
             AssistantEvent::SummaryChanged => {
                 cx.emit(AssistantEditorEvent::TabContentChanged);
@@ -1115,7 +1011,9 @@ impl AssistantEditor {
             let mut copied_text = String::new();
             let mut spanned_messages = 0;
             for message in &assistant.messages {
-                let message_range = offset..offset + message.content.read(cx).len() + 1;
+                // TODO
+                // let message_range = offset..offset + message.content.read(cx).len() + 1;
+                let message_range = offset..offset + 1;
 
                 if message_range.start >= selection.range().end {
                     break;
@@ -1123,13 +1021,10 @@ impl AssistantEditor {
                     let range = cmp::max(message_range.start, selection.range().start)
                         ..cmp::min(message_range.end, selection.range().end);
                     if !range.is_empty() {
-                        if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id)
-                        {
+                        if let Some(metadata) = assistant.messages_metadata.get(&message.id) {
                             spanned_messages += 1;
                             write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
-                            for chunk in
-                                assistant.buffer.read(cx).snapshot(cx).text_for_range(range)
-                            {
+                            for chunk in assistant.buffer.read(cx).text_for_range(range) {
                                 copied_text.push_str(&chunk);
                             }
                             copied_text.push('\n');
@@ -1255,10 +1150,13 @@ impl Item for AssistantEditor {
     }
 }
 
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
+struct MessageId(usize);
+
 #[derive(Clone, Debug)]
 struct Message {
-    excerpt_id: ExcerptId,
-    content: ModelHandle<Buffer>,
+    id: MessageId,
+    start: language::Anchor,
 }
 
 #[derive(Clone, Debug)]
@@ -1366,17 +1264,23 @@ mod tests {
         cx.add_model(|cx| {
             let mut assistant = Assistant::new(Default::default(), registry, cx);
             let message_1 = assistant.messages[0].clone();
-            let message_2 = assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
-            let message_3 = assistant.insert_message_after(message_2.excerpt_id, Role::User, cx);
-            let message_4 = assistant.insert_message_after(message_2.excerpt_id, Role::User, cx);
+            let message_2 = assistant
+                .insert_message_after(message_1.id, Role::Assistant, cx)
+                .unwrap();
+            let message_3 = assistant
+                .insert_message_after(message_2.id, Role::User, cx)
+                .unwrap();
+            let message_4 = assistant
+                .insert_message_after(message_2.id, Role::User, cx)
+                .unwrap();
             assistant.remove_empty_messages(
-                HashSet::from_iter([message_3.excerpt_id, message_4.excerpt_id]),
+                HashSet::from_iter([message_3.id, message_4.id]),
                 Default::default(),
                 cx,
             );
             assert_eq!(assistant.messages.len(), 2);
-            assert_eq!(assistant.messages[0].excerpt_id, message_1.excerpt_id);
-            assert_eq!(assistant.messages[1].excerpt_id, message_2.excerpt_id);
+            assert_eq!(assistant.messages[0].id, message_1.id);
+            assert_eq!(assistant.messages[1].id, message_2.id);
             assistant
         });
     }