:lipstick:

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 224 +++++++++++++++++++--------------------
1 file changed, 110 insertions(+), 114 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -611,6 +611,7 @@ impl Assistant {
                                         };
                                         buffer.edit([(offset..offset, text)], None, cx);
                                     });
+                                    cx.emit(AssistantEvent::StreamedCompletion);
 
                                     Some(())
                                 });
@@ -745,7 +746,7 @@ impl Assistant {
     fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
         let buffer = self.buffer.read(cx);
         self.messages(cx)
-            .map(|(message, metadata, range)| RequestMessage {
+            .map(|(_message, metadata, range)| RequestMessage {
                 role: metadata.role,
                 content: buffer.text_for_range(range).collect(),
             })
@@ -828,7 +829,7 @@ impl AssistantEditor {
             cx.subscribe(&editor, Self::handle_editor_event),
         ];
 
-        Self {
+        let mut this = Self {
             assistant,
             editor,
             blocks: Default::default(),
@@ -837,7 +838,9 @@ impl AssistantEditor {
                 anchor: Anchor::max(),
             },
             _subscriptions,
-        }
+        };
+        this.update_message_headers(cx);
+        this
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
@@ -891,116 +894,10 @@ impl AssistantEditor {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            AssistantEvent::MessagesEdited => {
-                self.editor.update(cx, |editor, cx| {
-                    let buffer = editor.buffer().read(cx).snapshot(cx);
-                    let excerpt_id = *buffer.as_singleton().unwrap().0;
-                    let old_blocks = std::mem::take(&mut self.blocks);
-                    let new_blocks =
-                        self.assistant
-                            .read(cx)
-                            .messages(cx)
-                            .map(|(message, metadata, _)| BlockProperties {
-                                position: buffer.anchor_in_excerpt(excerpt_id, message.start),
-                                height: 2,
-                                style: BlockStyle::Sticky,
-                                render: Arc::new({
-                                    let assistant = self.assistant.clone();
-                                    let metadata = metadata.clone();
-                                    let message = message.clone();
-                                    move |cx| {
-                                        enum Sender {}
-                                        enum ErrorTooltip {}
-
-                                        let theme = theme::current(cx);
-                                        let style = &theme.assistant;
-                                        let message_id = message.id;
-                                        let sender = MouseEventHandler::<Sender, _>::new(
-                                            message_id.0,
-                                            cx,
-                                            |state, _| match metadata.role {
-                                                Role::User => {
-                                                    let style =
-                                                        style.user_sender.style_for(state, false);
-                                                    Label::new("You", style.text.clone())
-                                                        .contained()
-                                                        .with_style(style.container)
-                                                }
-                                                Role::Assistant => {
-                                                    let style = style
-                                                        .assistant_sender
-                                                        .style_for(state, false);
-                                                    Label::new("Assistant", style.text.clone())
-                                                        .contained()
-                                                        .with_style(style.container)
-                                                }
-                                                Role::System => {
-                                                    let style =
-                                                        style.system_sender.style_for(state, false);
-                                                    Label::new("System", style.text.clone())
-                                                        .contained()
-                                                        .with_style(style.container)
-                                                }
-                                            },
-                                        )
-                                        .with_cursor_style(CursorStyle::PointingHand)
-                                        .on_down(MouseButton::Left, {
-                                            let assistant = assistant.clone();
-                                            move |_, _, cx| {
-                                                assistant.update(cx, |assistant, cx| {
-                                                    assistant.cycle_message_role(message_id, cx)
-                                                })
-                                            }
-                                        });
-
-                                        Flex::row()
-                                            .with_child(sender.aligned())
-                                            .with_child(
-                                                Label::new(
-                                                    metadata.sent_at.format("%I:%M%P").to_string(),
-                                                    style.sent_at.text.clone(),
-                                                )
-                                                .contained()
-                                                .with_style(style.sent_at.container)
-                                                .aligned(),
-                                            )
-                                            .with_children(metadata.error.clone().map(|error| {
-                                                Svg::new("icons/circle_x_mark_12.svg")
-                                                    .with_color(style.error_icon.color)
-                                                    .constrained()
-                                                    .with_width(style.error_icon.width)
-                                                    .contained()
-                                                    .with_style(style.error_icon.container)
-                                                    .with_tooltip::<ErrorTooltip>(
-                                                        message_id.0,
-                                                        error,
-                                                        None,
-                                                        theme.tooltip.clone(),
-                                                        cx,
-                                                    )
-                                                    .aligned()
-                                            }))
-                                            .aligned()
-                                            .left()
-                                            .contained()
-                                            .with_style(style.header)
-                                            .into_any()
-                                    }
-                                }),
-                                disposition: BlockDisposition::Above,
-                            })
-                            .collect::<Vec<_>>();
-
-                    editor.remove_blocks(old_blocks, cx);
-                    let ids = editor.insert_blocks(new_blocks, cx);
-                    self.blocks = HashSet::from_iter(ids);
-                });
-            }
-
+            AssistantEvent::MessagesEdited => self.update_message_headers(cx),
             AssistantEvent::SummaryChanged => {
                 cx.emit(AssistantEditorEvent::TabContentChanged);
             }
-
             AssistantEvent::StreamedCompletion => {
                 self.editor.update(cx, |editor, cx| {
                     let snapshot = editor.snapshot(cx);
@@ -1032,6 +929,108 @@ impl AssistantEditor {
         }
     }
 
+    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
+        self.editor.update(cx, |editor, cx| {
+            let buffer = editor.buffer().read(cx).snapshot(cx);
+            let excerpt_id = *buffer.as_singleton().unwrap().0;
+            let old_blocks = std::mem::take(&mut self.blocks);
+            let new_blocks = self
+                .assistant
+                .read(cx)
+                .messages(cx)
+                .map(|(message, metadata, _)| BlockProperties {
+                    position: buffer.anchor_in_excerpt(excerpt_id, message.start),
+                    height: 2,
+                    style: BlockStyle::Sticky,
+                    render: Arc::new({
+                        let assistant = self.assistant.clone();
+                        let metadata = metadata.clone();
+                        let message = message.clone();
+                        move |cx| {
+                            enum Sender {}
+                            enum ErrorTooltip {}
+
+                            let theme = theme::current(cx);
+                            let style = &theme.assistant;
+                            let message_id = message.id;
+                            let sender = MouseEventHandler::<Sender, _>::new(
+                                message_id.0,
+                                cx,
+                                |state, _| match metadata.role {
+                                    Role::User => {
+                                        let style = style.user_sender.style_for(state, false);
+                                        Label::new("You", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                    Role::Assistant => {
+                                        let style = style.assistant_sender.style_for(state, false);
+                                        Label::new("Assistant", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                    Role::System => {
+                                        let style = style.system_sender.style_for(state, false);
+                                        Label::new("System", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                },
+                            )
+                            .with_cursor_style(CursorStyle::PointingHand)
+                            .on_down(MouseButton::Left, {
+                                let assistant = assistant.clone();
+                                move |_, _, cx| {
+                                    assistant.update(cx, |assistant, cx| {
+                                        assistant.cycle_message_role(message_id, cx)
+                                    })
+                                }
+                            });
+
+                            Flex::row()
+                                .with_child(sender.aligned())
+                                .with_child(
+                                    Label::new(
+                                        metadata.sent_at.format("%I:%M%P").to_string(),
+                                        style.sent_at.text.clone(),
+                                    )
+                                    .contained()
+                                    .with_style(style.sent_at.container)
+                                    .aligned(),
+                                )
+                                .with_children(metadata.error.clone().map(|error| {
+                                    Svg::new("icons/circle_x_mark_12.svg")
+                                        .with_color(style.error_icon.color)
+                                        .constrained()
+                                        .with_width(style.error_icon.width)
+                                        .contained()
+                                        .with_style(style.error_icon.container)
+                                        .with_tooltip::<ErrorTooltip>(
+                                            message_id.0,
+                                            error,
+                                            None,
+                                            theme.tooltip.clone(),
+                                            cx,
+                                        )
+                                        .aligned()
+                                }))
+                                .aligned()
+                                .left()
+                                .contained()
+                                .with_style(style.header)
+                                .into_any()
+                        }
+                    }),
+                    disposition: BlockDisposition::Above,
+                })
+                .collect::<Vec<_>>();
+
+            editor.remove_blocks(old_blocks, cx);
+            let ids = editor.insert_blocks(new_blocks, cx);
+            self.blocks = HashSet::from_iter(ids);
+        });
+    }
+
     fn update_scroll_bottom(&mut self, cx: &mut ViewContext<Self>) {
         self.editor.update(cx, |editor, cx| {
             let snapshot = editor.snapshot(cx);
@@ -1128,10 +1127,9 @@ impl AssistantEditor {
         let assistant = self.assistant.read(cx);
         if editor.selections.count() == 1 {
             let selection = editor.selections.newest::<usize>(cx);
-            let mut offset = 0;
             let mut copied_text = String::new();
             let mut spanned_messages = 0;
-            for (message, metadata, message_range) in assistant.messages(cx) {
+            for (_message, metadata, message_range) in assistant.messages(cx) {
                 if message_range.start >= selection.range().end {
                     break;
                 } else if message_range.end >= selection.range().start {
@@ -1146,8 +1144,6 @@ impl AssistantEditor {
                         copied_text.push('\n');
                     }
                 }
-
-                offset = message_range.end;
             }
 
             if spanned_messages > 1 {