active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use editor::{Editor, MultiBuffer};
  6use gpui::{
  7    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity,
  8    Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
  9    TextStyleRefinement, UnderlineStyle, WeakEntity,
 10};
 11use language::{Buffer, LanguageRegistry};
 12use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 13use markdown::{Markdown, MarkdownStyle};
 14use settings::Settings as _;
 15use theme::ThemeSettings;
 16use ui::{prelude::*, Disclosure, KeyBinding};
 17use workspace::Workspace;
 18
 19use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 20use crate::thread_store::ThreadStore;
 21use crate::tool_use::{ToolUse, ToolUseStatus};
 22use crate::ui::ContextPill;
 23
/// View that renders the messages of a single [`Thread`] as a scrollable list
/// and reacts to the thread's events (new/edited/deleted messages, streamed
/// text, tool use, errors).
pub struct ActiveThread {
    /// Workspace handle handed to tools when a pending tool use is run.
    workspace: WeakEntity<Workspace>,
    /// Passed to each `Markdown` entity created for a message.
    language_registry: Arc<LanguageRegistry>,
    /// Tools that are looked up by name when the thread requests a tool use.
    tools: Arc<ToolWorkingSet>,
    /// Store used to save the thread after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Ids of the messages in list order; the item count of `list_state` is
    /// spliced in step with every change to this vector.
    messages: Vec<MessageId>,
    /// State of the GPUI list whose items are rendered by `render_message`.
    list_state: ListState,
    /// Rendered markdown for each message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited (if any) together with its editor.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Whether each tool-use card is expanded; a missing entry reads as collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported through `ThreadEvent::ShowError`.
    last_error: Option<ThreadError>,
    /// Subscriptions on `thread` created in `new`, stored for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
 38
/// State kept while a message is being edited in place.
struct EditMessageState {
    /// Editor holding the text of the message being edited.
    editor: Entity<Editor>,
}
 42
 43impl ActiveThread {
 44    pub fn new(
 45        thread: Entity<Thread>,
 46        thread_store: Entity<ThreadStore>,
 47        workspace: WeakEntity<Workspace>,
 48        language_registry: Arc<LanguageRegistry>,
 49        tools: Arc<ToolWorkingSet>,
 50        window: &mut Window,
 51        cx: &mut Context<Self>,
 52    ) -> Self {
 53        let subscriptions = vec![
 54            cx.observe(&thread, |_, _, cx| cx.notify()),
 55            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 56        ];
 57
 58        let mut this = Self {
 59            workspace,
 60            language_registry,
 61            tools,
 62            thread_store,
 63            thread: thread.clone(),
 64            messages: Vec::new(),
 65            rendered_messages_by_id: HashMap::default(),
 66            expanded_tool_uses: HashMap::default(),
 67            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 68                let this = cx.entity().downgrade();
 69                move |ix, window: &mut Window, cx: &mut App| {
 70                    this.update(cx, |this, cx| this.render_message(ix, window, cx))
 71                        .unwrap()
 72                }
 73            }),
 74            editing_message: None,
 75            last_error: None,
 76            _subscriptions: subscriptions,
 77        };
 78
 79        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 80            this.push_message(&message.id, message.text.clone(), window, cx);
 81        }
 82
 83        this
 84    }
 85
 86    pub fn thread(&self) -> &Entity<Thread> {
 87        &self.thread
 88    }
 89
 90    pub fn is_empty(&self) -> bool {
 91        self.messages.is_empty()
 92    }
 93
 94    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 95        self.thread.read(cx).summary()
 96    }
 97
 98    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 99        self.thread.read(cx).summary_or_default()
100    }
101
102    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
103        self.last_error.take();
104        self.thread
105            .update(cx, |thread, _cx| thread.cancel_last_completion())
106    }
107
108    pub fn last_error(&self) -> Option<ThreadError> {
109        self.last_error.clone()
110    }
111
112    pub fn clear_last_error(&mut self) {
113        self.last_error.take();
114    }
115
116    fn push_message(
117        &mut self,
118        id: &MessageId,
119        text: String,
120        window: &mut Window,
121        cx: &mut Context<Self>,
122    ) {
123        let old_len = self.messages.len();
124        self.messages.push(*id);
125        self.list_state.splice(old_len..old_len, 1);
126
127        let markdown = self.render_markdown(text.into(), window, cx);
128        self.rendered_messages_by_id.insert(*id, markdown);
129        self.list_state.scroll_to(ListOffset {
130            item_ix: old_len,
131            offset_in_item: Pixels(0.0),
132        });
133    }
134
135    fn edited_message(
136        &mut self,
137        id: &MessageId,
138        text: String,
139        window: &mut Window,
140        cx: &mut Context<Self>,
141    ) {
142        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
143            return;
144        };
145        self.list_state.splice(index..index + 1, 1);
146        let markdown = self.render_markdown(text.into(), window, cx);
147        self.rendered_messages_by_id.insert(*id, markdown);
148    }
149
150    fn deleted_message(&mut self, id: &MessageId) {
151        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
152            return;
153        };
154        self.messages.remove(index);
155        self.list_state.splice(index..index + 1, 0);
156        self.rendered_messages_by_id.remove(id);
157    }
158
159    fn render_markdown(
160        &self,
161        text: SharedString,
162        window: &Window,
163        cx: &mut Context<Self>,
164    ) -> Entity<Markdown> {
165        let theme_settings = ThemeSettings::get_global(cx);
166        let colors = cx.theme().colors();
167        let ui_font_size = TextSize::Default.rems(cx);
168        let buffer_font_size = TextSize::Small.rems(cx);
169        let mut text_style = window.text_style();
170
171        text_style.refine(&TextStyleRefinement {
172            font_family: Some(theme_settings.ui_font.family.clone()),
173            font_size: Some(ui_font_size.into()),
174            color: Some(cx.theme().colors().text),
175            ..Default::default()
176        });
177
178        let markdown_style = MarkdownStyle {
179            base_text_style: text_style,
180            syntax: cx.theme().syntax().clone(),
181            selection_background_color: cx.theme().players().local().selection,
182            code_block_overflow_x_scroll: true,
183            table_overflow_x_scroll: true,
184            code_block: StyleRefinement {
185                margin: EdgesRefinement {
186                    top: Some(Length::Definite(rems(0.).into())),
187                    left: Some(Length::Definite(rems(0.).into())),
188                    right: Some(Length::Definite(rems(0.).into())),
189                    bottom: Some(Length::Definite(rems(0.5).into())),
190                },
191                padding: EdgesRefinement {
192                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
194                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
195                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
196                },
197                background: Some(colors.editor_background.into()),
198                border_color: Some(colors.border_variant),
199                border_widths: EdgesRefinement {
200                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
201                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
202                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
203                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
204                },
205                text: Some(TextStyleRefinement {
206                    font_family: Some(theme_settings.buffer_font.family.clone()),
207                    font_size: Some(buffer_font_size.into()),
208                    ..Default::default()
209                }),
210                ..Default::default()
211            },
212            inline_code: TextStyleRefinement {
213                font_family: Some(theme_settings.buffer_font.family.clone()),
214                font_size: Some(buffer_font_size.into()),
215                background_color: Some(colors.editor_foreground.opacity(0.1)),
216                ..Default::default()
217            },
218            link: TextStyleRefinement {
219                background_color: Some(colors.editor_foreground.opacity(0.025)),
220                underline: Some(UnderlineStyle {
221                    color: Some(colors.text_accent.opacity(0.5)),
222                    thickness: px(1.),
223                    ..Default::default()
224                }),
225                ..Default::default()
226            },
227            ..Default::default()
228        };
229
230        cx.new(|cx| {
231            Markdown::new(
232                text,
233                markdown_style,
234                Some(self.language_registry.clone()),
235                None,
236                cx,
237            )
238        })
239    }
240
241    fn handle_thread_event(
242        &mut self,
243        _: &Entity<Thread>,
244        event: &ThreadEvent,
245        window: &mut Window,
246        cx: &mut Context<Self>,
247    ) {
248        match event {
249            ThreadEvent::ShowError(error) => {
250                self.last_error = Some(error.clone());
251            }
252            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
253                self.thread_store
254                    .update(cx, |thread_store, cx| {
255                        thread_store.save_thread(&self.thread, cx)
256                    })
257                    .detach_and_log_err(cx);
258            }
259            ThreadEvent::StreamedAssistantText(message_id, text) => {
260                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
261                    markdown.update(cx, |markdown, cx| {
262                        markdown.append(text, cx);
263                    });
264                }
265            }
266            ThreadEvent::MessageAdded(message_id) => {
267                if let Some(message_text) = self
268                    .thread
269                    .read(cx)
270                    .message(*message_id)
271                    .map(|message| message.text.clone())
272                {
273                    self.push_message(message_id, message_text, window, cx);
274                }
275
276                self.thread_store
277                    .update(cx, |thread_store, cx| {
278                        thread_store.save_thread(&self.thread, cx)
279                    })
280                    .detach_and_log_err(cx);
281
282                cx.notify();
283            }
284            ThreadEvent::MessageEdited(message_id) => {
285                if let Some(message_text) = self
286                    .thread
287                    .read(cx)
288                    .message(*message_id)
289                    .map(|message| message.text.clone())
290                {
291                    self.edited_message(message_id, message_text, window, cx);
292                }
293
294                self.thread_store
295                    .update(cx, |thread_store, cx| {
296                        thread_store.save_thread(&self.thread, cx)
297                    })
298                    .detach_and_log_err(cx);
299
300                cx.notify();
301            }
302            ThreadEvent::MessageDeleted(message_id) => {
303                self.deleted_message(message_id);
304
305                self.thread_store
306                    .update(cx, |thread_store, cx| {
307                        thread_store.save_thread(&self.thread, cx)
308                    })
309                    .detach_and_log_err(cx);
310
311                cx.notify();
312            }
313            ThreadEvent::UsePendingTools => {
314                let pending_tool_uses = self
315                    .thread
316                    .read(cx)
317                    .pending_tool_uses()
318                    .into_iter()
319                    .filter(|tool_use| tool_use.status.is_idle())
320                    .cloned()
321                    .collect::<Vec<_>>();
322
323                for tool_use in pending_tool_uses {
324                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
325                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
326
327                        self.thread.update(cx, |thread, cx| {
328                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
329                        });
330                    }
331                }
332            }
333            ThreadEvent::ToolFinished { .. } => {
334                let all_tools_finished = self
335                    .thread
336                    .read(cx)
337                    .pending_tool_uses()
338                    .into_iter()
339                    .all(|tool_use| tool_use.status.is_error());
340                if all_tools_finished {
341                    let model_registry = LanguageModelRegistry::read_global(cx);
342                    if let Some(model) = model_registry.active_model() {
343                        self.thread.update(cx, |thread, cx| {
344                            // Insert a user message to contain the tool results.
345                            thread.insert_user_message(
346                                // TODO: Sending up a user message without any content results in the model sending back
347                                // responses that also don't have any content. We currently don't handle this case well,
348                                // so for now we provide some text to keep the model on track.
349                                "Here are the tool results.",
350                                Vec::new(),
351                                cx,
352                            );
353                            thread.send_to_model(model, RequestKind::Chat, true, cx);
354                        });
355                    }
356                }
357            }
358        }
359    }
360
361    fn start_editing_message(
362        &mut self,
363        message_id: MessageId,
364        message_text: String,
365        window: &mut Window,
366        cx: &mut Context<Self>,
367    ) {
368        let buffer = cx.new(|cx| {
369            MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
370        });
371        let editor = cx.new(|cx| {
372            let mut editor = Editor::new(
373                editor::EditorMode::AutoHeight { max_lines: 8 },
374                buffer,
375                None,
376                false,
377                window,
378                cx,
379            );
380            editor.focus_handle(cx).focus(window);
381            editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
382            editor
383        });
384        self.editing_message = Some((
385            message_id,
386            EditMessageState {
387                editor: editor.clone(),
388            },
389        ));
390        cx.notify();
391    }
392
393    fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
394        self.editing_message.take();
395        cx.notify();
396    }
397
398    fn confirm_editing_message(
399        &mut self,
400        _: &menu::Confirm,
401        _: &mut Window,
402        cx: &mut Context<Self>,
403    ) {
404        let Some((message_id, state)) = self.editing_message.take() else {
405            return;
406        };
407        let edited_text = state.editor.read(cx).text(cx);
408        self.thread.update(cx, |thread, cx| {
409            thread.edit_message(message_id, Role::User, edited_text, cx);
410            for message_id in self.messages_after(message_id) {
411                thread.delete_message(*message_id, cx);
412            }
413        });
414
415        let provider = LanguageModelRegistry::read_global(cx).active_provider();
416        if provider
417            .as_ref()
418            .map_or(false, |provider| provider.must_accept_terms(cx))
419        {
420            cx.notify();
421            return;
422        }
423        let model_registry = LanguageModelRegistry::read_global(cx);
424        let Some(model) = model_registry.active_model() else {
425            return;
426        };
427
428        self.thread.update(cx, |thread, cx| {
429            thread.send_to_model(model, RequestKind::Chat, false, cx)
430        });
431        cx.notify();
432    }
433
434    fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
435        self.messages
436            .iter()
437            .rev()
438            .find(|message_id| {
439                self.thread
440                    .read(cx)
441                    .message(**message_id)
442                    .map_or(false, |message| message.role == Role::User)
443            })
444            .cloned()
445    }
446
447    fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
448        self.messages
449            .iter()
450            .position(|id| *id == message_id)
451            .map(|index| &self.messages[index + 1..])
452            .unwrap_or(&[])
453    }
454
    /// Builds the list element for the message at index `ix` of `self.messages`.
    ///
    /// Returns an `Empty` element when the thread has no message with that id,
    /// when no markdown has been rendered for it, or when it is a user message
    /// that only carries tool results. Otherwise the message is styled by role:
    /// user messages get a bordered card with a "You" header and edit controls,
    /// assistant messages are followed by their tool uses, and system messages
    /// get a plain rounded background.
    ///
    /// Panics if `ix` is out of bounds for `self.messages`.
    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
        let message_id = self.messages[ix];
        let Some(message) = self.thread.read(cx).message(message_id) else {
            return Empty.into_any();
        };

        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
            return Empty.into_any();
        };

        let context = self.thread.read(cx).context_for_message(message_id);
        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
        let colors = cx.theme().colors();

        // Don't render user messages that are just there for returning tool results.
        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
            return Empty.into_any();
        }

        // Only the most recent user message can be edited.
        let allow_editing_message =
            message.role == Role::User && self.last_user_message(cx) == Some(message_id);

        // `Some` only when this very message is the one currently being edited.
        let edit_message_editor = self
            .editing_message
            .as_ref()
            .filter(|(id, _)| *id == message_id)
            .map(|(_, state)| state.editor.clone());

        // Body of the message: either the inline editor or the rendered
        // markdown, followed by the message's context pills (if any).
        let message_content = v_flex()
            .child(
                if let Some(edit_message_editor) = edit_message_editor.clone() {
                    div()
                        .key_context("EditMessageEditor")
                        .on_action(cx.listener(Self::cancel_editing_message))
                        .on_action(cx.listener(Self::confirm_editing_message))
                        .p_2p5()
                        .child(edit_message_editor)
                } else {
                    div().p_2p5().text_ui(cx).child(markdown.clone())
                },
            )
            .when_some(context, |parent, context| {
                if !context.is_empty() {
                    parent.child(
                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
                            context
                                .into_iter()
                                .map(|context| ContextPill::added(context, false, false, None)),
                        ),
                    )
                } else {
                    parent
                }
            });

        let styled_message = match message.role {
            Role::User => v_flex()
                .id(("message-container", ix))
                .pt_2p5()
                .px_2p5()
                .child(
                    v_flex()
                        .bg(colors.editor_background)
                        .rounded_lg()
                        .border_1()
                        .border_color(colors.border)
                        .shadow_sm()
                        .child(
                            // Header row: author label on the left, edit controls on the right.
                            h_flex()
                                .py_1()
                                .px_2()
                                .bg(colors.editor_foreground.opacity(0.05))
                                .border_b_1()
                                .border_color(colors.border)
                                .justify_between()
                                .rounded_t(px(6.))
                                .child(
                                    h_flex()
                                        .gap_1p5()
                                        .child(
                                            Icon::new(IconName::PersonCircle)
                                                .size(IconSize::XSmall)
                                                .color(Color::Muted),
                                        )
                                        .child(
                                            Label::new("You")
                                                .size(LabelSize::Small)
                                                .color(Color::Muted),
                                        ),
                                )
                                // While editing: show Cancel / Regenerate with their key bindings.
                                .when_some(
                                    edit_message_editor.clone(),
                                    |this, edit_message_editor| {
                                        let focus_handle = edit_message_editor.focus_handle(cx);
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Button::new("cancel-edit-message", "Cancel")
                                                        .key_binding(KeyBinding::for_action_in(
                                                            &menu::Cancel,
                                                            &focus_handle,
                                                            window,
                                                            cx,
                                                        )),
                                                )
                                                .child(
                                                    Button::new(
                                                        "confirm-edit-message",
                                                        "Regenerate",
                                                    )
                                                    .key_binding(KeyBinding::for_action_in(
                                                        &menu::Confirm,
                                                        &focus_handle,
                                                        window,
                                                        cx,
                                                    )),
                                                ),
                                        )
                                    },
                                )
                                // Not editing, and this message is editable: show the Edit button.
                                .when(
                                    edit_message_editor.is_none() && allow_editing_message,
                                    |this| {
                                        this.child(Button::new("edit-message", "Edit").on_click(
                                            cx.listener({
                                                let message_text = message.text.clone();
                                                move |this, _, window, cx| {
                                                    this.start_editing_message(
                                                        message_id,
                                                        message_text.clone(),
                                                        window,
                                                        cx,
                                                    );
                                                }
                                            }),
                                        ))
                                    },
                                ),
                        )
                        .child(message_content),
                ),
            Role::Assistant => div()
                .id(("message-container", ix))
                .child(message_content)
                .map(|parent| {
                    if tool_uses.is_empty() {
                        return parent;
                    }

                    // Tool uses of this message are stacked below its text.
                    parent.child(
                        v_flex().children(
                            tool_uses
                                .into_iter()
                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
                        ),
                    )
                }),
            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
                v_flex()
                    .bg(colors.editor_background)
                    .rounded_md()
                    .child(message_content),
            ),
        };

        styled_message.into_any()
    }
623
624    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
625        let is_open = self
626            .expanded_tool_uses
627            .get(&tool_use.id)
628            .copied()
629            .unwrap_or_default();
630
631        div().px_2p5().child(
632            v_flex()
633                .gap_1()
634                .rounded_lg()
635                .border_1()
636                .border_color(cx.theme().colors().border)
637                .child(
638                    h_flex()
639                        .justify_between()
640                        .py_0p5()
641                        .pl_1()
642                        .pr_2()
643                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
644                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
645                        .when(!is_open, |element| element.rounded(px(6.)))
646                        .border_color(cx.theme().colors().border)
647                        .child(
648                            h_flex()
649                                .gap_1()
650                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
651                                    cx.listener({
652                                        let tool_use_id = tool_use.id.clone();
653                                        move |this, _event, _window, _cx| {
654                                            let is_open = this
655                                                .expanded_tool_uses
656                                                .entry(tool_use_id.clone())
657                                                .or_insert(false);
658
659                                            *is_open = !*is_open;
660                                        }
661                                    }),
662                                ))
663                                .child(Label::new(tool_use.name)),
664                        )
665                        .child(
666                            Label::new(match tool_use.status {
667                                ToolUseStatus::Pending => "Pending",
668                                ToolUseStatus::Running => "Running",
669                                ToolUseStatus::Finished(_) => "Finished",
670                                ToolUseStatus::Error(_) => "Error",
671                            })
672                            .size(LabelSize::XSmall)
673                            .buffer_font(cx),
674                        ),
675                )
676                .map(|parent| {
677                    if !is_open {
678                        return parent;
679                    }
680
681                    parent.child(
682                        v_flex()
683                            .child(
684                                v_flex()
685                                    .gap_0p5()
686                                    .py_1()
687                                    .px_2p5()
688                                    .border_b_1()
689                                    .border_color(cx.theme().colors().border)
690                                    .child(Label::new("Input:"))
691                                    .child(Label::new(
692                                        serde_json::to_string_pretty(&tool_use.input)
693                                            .unwrap_or_default(),
694                                    )),
695                            )
696                            .map(|parent| match tool_use.status {
697                                ToolUseStatus::Finished(output) => parent.child(
698                                    v_flex()
699                                        .gap_0p5()
700                                        .py_1()
701                                        .px_2p5()
702                                        .child(Label::new("Result:"))
703                                        .child(Label::new(output)),
704                                ),
705                                ToolUseStatus::Error(err) => parent.child(
706                                    v_flex()
707                                        .gap_0p5()
708                                        .py_1()
709                                        .px_2p5()
710                                        .child(Label::new("Error:"))
711                                        .child(Label::new(err)),
712                                ),
713                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
714                            }),
715                    )
716                }),
717        )
718    }
719}
720
721impl Render for ActiveThread {
722    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
723        v_flex()
724            .size_full()
725            .child(list(self.list_state.clone()).flex_grow())
726    }
727}