active_thread.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use editor::{Editor, MultiBuffer};
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
  7    Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
  8    Task, TextStyleRefinement, UnderlineStyle,
  9};
 10use language::{Buffer, LanguageRegistry};
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure, KeyBinding};
 16use util::ResultExt as _;
 17
 18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::tool_use::{ToolUse, ToolUseStatus};
 21use crate::ui::ContextPill;
 22
 23pub struct ActiveThread {
 24    language_registry: Arc<LanguageRegistry>,
 25    thread_store: Entity<ThreadStore>,
 26    thread: Entity<Thread>,
 27    save_thread_task: Option<Task<()>>,
 28    messages: Vec<MessageId>,
 29    list_state: ListState,
 30    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
 31    editing_message: Option<(MessageId, EditMessageState)>,
 32    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
 33    last_error: Option<ThreadError>,
 34    _subscriptions: Vec<Subscription>,
 35}
 36
 37struct EditMessageState {
 38    editor: Entity<Editor>,
 39}
 40
 41impl ActiveThread {
 42    pub fn new(
 43        thread: Entity<Thread>,
 44        thread_store: Entity<ThreadStore>,
 45        language_registry: Arc<LanguageRegistry>,
 46        window: &mut Window,
 47        cx: &mut Context<Self>,
 48    ) -> Self {
 49        let subscriptions = vec![
 50            cx.observe(&thread, |_, _, cx| cx.notify()),
 51            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 52        ];
 53
 54        let mut this = Self {
 55            language_registry,
 56            thread_store,
 57            thread: thread.clone(),
 58            save_thread_task: None,
 59            messages: Vec::new(),
 60            rendered_messages_by_id: HashMap::default(),
 61            expanded_tool_uses: HashMap::default(),
 62            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 63                let this = cx.entity().downgrade();
 64                move |ix, window: &mut Window, cx: &mut App| {
 65                    this.update(cx, |this, cx| this.render_message(ix, window, cx))
 66                        .unwrap()
 67                }
 68            }),
 69            editing_message: None,
 70            last_error: None,
 71            _subscriptions: subscriptions,
 72        };
 73
 74        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 75            this.push_message(&message.id, message.text.clone(), window, cx);
 76        }
 77
 78        this
 79    }
 80
 81    pub fn thread(&self) -> &Entity<Thread> {
 82        &self.thread
 83    }
 84
 85    pub fn is_empty(&self) -> bool {
 86        self.messages.is_empty()
 87    }
 88
 89    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 90        self.thread.read(cx).summary()
 91    }
 92
 93    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 94        self.thread.read(cx).summary_or_default()
 95    }
 96
 97    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 98        self.last_error.take();
 99        self.thread
100            .update(cx, |thread, _cx| thread.cancel_last_completion())
101    }
102
103    pub fn last_error(&self) -> Option<ThreadError> {
104        self.last_error.clone()
105    }
106
107    pub fn clear_last_error(&mut self) {
108        self.last_error.take();
109    }
110
111    fn push_message(
112        &mut self,
113        id: &MessageId,
114        text: String,
115        window: &mut Window,
116        cx: &mut Context<Self>,
117    ) {
118        let old_len = self.messages.len();
119        self.messages.push(*id);
120        self.list_state.splice(old_len..old_len, 1);
121
122        let markdown = self.render_markdown(text.into(), window, cx);
123        self.rendered_messages_by_id.insert(*id, markdown);
124        self.list_state.scroll_to(ListOffset {
125            item_ix: old_len,
126            offset_in_item: Pixels(0.0),
127        });
128    }
129
130    fn edited_message(
131        &mut self,
132        id: &MessageId,
133        text: String,
134        window: &mut Window,
135        cx: &mut Context<Self>,
136    ) {
137        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
138            return;
139        };
140        self.list_state.splice(index..index + 1, 1);
141        let markdown = self.render_markdown(text.into(), window, cx);
142        self.rendered_messages_by_id.insert(*id, markdown);
143    }
144
145    fn deleted_message(&mut self, id: &MessageId) {
146        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
147            return;
148        };
149        self.messages.remove(index);
150        self.list_state.splice(index..index + 1, 0);
151        self.rendered_messages_by_id.remove(id);
152    }
153
154    fn render_markdown(
155        &self,
156        text: SharedString,
157        window: &Window,
158        cx: &mut Context<Self>,
159    ) -> Entity<Markdown> {
160        let theme_settings = ThemeSettings::get_global(cx);
161        let colors = cx.theme().colors();
162        let ui_font_size = TextSize::Default.rems(cx);
163        let buffer_font_size = TextSize::Small.rems(cx);
164        let mut text_style = window.text_style();
165
166        text_style.refine(&TextStyleRefinement {
167            font_family: Some(theme_settings.ui_font.family.clone()),
168            font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
169            font_features: Some(theme_settings.ui_font.features.clone()),
170            font_size: Some(ui_font_size.into()),
171            color: Some(cx.theme().colors().text),
172            ..Default::default()
173        });
174
175        let markdown_style = MarkdownStyle {
176            base_text_style: text_style,
177            syntax: cx.theme().syntax().clone(),
178            selection_background_color: cx.theme().players().local().selection,
179            code_block_overflow_x_scroll: true,
180            table_overflow_x_scroll: true,
181            code_block: StyleRefinement {
182                margin: EdgesRefinement {
183                    top: Some(Length::Definite(rems(0.).into())),
184                    left: Some(Length::Definite(rems(0.).into())),
185                    right: Some(Length::Definite(rems(0.).into())),
186                    bottom: Some(Length::Definite(rems(0.5).into())),
187                },
188                padding: EdgesRefinement {
189                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
190                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
191                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
192                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193                },
194                background: Some(colors.editor_background.into()),
195                border_color: Some(colors.border_variant),
196                border_widths: EdgesRefinement {
197                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
198                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
199                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
200                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
201                },
202                text: Some(TextStyleRefinement {
203                    font_family: Some(theme_settings.buffer_font.family.clone()),
204                    font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
205                    font_features: Some(theme_settings.buffer_font.features.clone()),
206                    font_size: Some(buffer_font_size.into()),
207                    ..Default::default()
208                }),
209                ..Default::default()
210            },
211            inline_code: TextStyleRefinement {
212                font_family: Some(theme_settings.buffer_font.family.clone()),
213                font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
214                font_features: Some(theme_settings.buffer_font.features.clone()),
215                font_size: Some(buffer_font_size.into()),
216                background_color: Some(colors.editor_foreground.opacity(0.1)),
217                ..Default::default()
218            },
219            link: TextStyleRefinement {
220                background_color: Some(colors.editor_foreground.opacity(0.025)),
221                underline: Some(UnderlineStyle {
222                    color: Some(colors.text_accent.opacity(0.5)),
223                    thickness: px(1.),
224                    ..Default::default()
225                }),
226                ..Default::default()
227            },
228            ..Default::default()
229        };
230
231        cx.new(|cx| {
232            Markdown::new(
233                text,
234                markdown_style,
235                Some(self.language_registry.clone()),
236                None,
237                cx,
238            )
239        })
240    }
241
242    fn handle_thread_event(
243        &mut self,
244        _: &Entity<Thread>,
245        event: &ThreadEvent,
246        window: &mut Window,
247        cx: &mut Context<Self>,
248    ) {
249        match event {
250            ThreadEvent::ShowError(error) => {
251                self.last_error = Some(error.clone());
252            }
253            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
254                self.save_thread(cx);
255            }
256            ThreadEvent::StreamedAssistantText(message_id, text) => {
257                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
258                    markdown.update(cx, |markdown, cx| {
259                        markdown.append(text, cx);
260                    });
261                }
262            }
263            ThreadEvent::MessageAdded(message_id) => {
264                if let Some(message_text) = self
265                    .thread
266                    .read(cx)
267                    .message(*message_id)
268                    .map(|message| message.text.clone())
269                {
270                    self.push_message(message_id, message_text, window, cx);
271                }
272
273                self.save_thread(cx);
274                cx.notify();
275            }
276            ThreadEvent::MessageEdited(message_id) => {
277                if let Some(message_text) = self
278                    .thread
279                    .read(cx)
280                    .message(*message_id)
281                    .map(|message| message.text.clone())
282                {
283                    self.edited_message(message_id, message_text, window, cx);
284                }
285
286                self.save_thread(cx);
287                cx.notify();
288            }
289            ThreadEvent::MessageDeleted(message_id) => {
290                self.deleted_message(message_id);
291                self.save_thread(cx);
292                cx.notify();
293            }
294            ThreadEvent::UsePendingTools => {
295                self.thread.update(cx, |thread, cx| {
296                    thread.use_pending_tools(cx);
297                });
298            }
299            ThreadEvent::ToolFinished { .. } => {
300                let all_tools_finished = self
301                    .thread
302                    .read(cx)
303                    .pending_tool_uses()
304                    .into_iter()
305                    .all(|tool_use| tool_use.status.is_error());
306                if all_tools_finished {
307                    let model_registry = LanguageModelRegistry::read_global(cx);
308                    if let Some(model) = model_registry.active_model() {
309                        self.thread.update(cx, |thread, cx| {
310                            thread.send_tool_results_to_model(model, cx);
311                        });
312                    }
313                }
314            }
315        }
316    }
317
318    /// Spawns a task to save the active thread.
319    ///
320    /// Only one task to save the thread will be in flight at a time.
321    fn save_thread(&mut self, cx: &mut Context<Self>) {
322        let thread = self.thread.clone();
323        self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
324            let task = this
325                .update(&mut cx, |this, cx| {
326                    this.thread_store
327                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
328                })
329                .ok();
330
331            if let Some(task) = task {
332                task.await.log_err();
333            }
334        }));
335    }
336
337    fn start_editing_message(
338        &mut self,
339        message_id: MessageId,
340        message_text: String,
341        window: &mut Window,
342        cx: &mut Context<Self>,
343    ) {
344        let buffer = cx.new(|cx| {
345            MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
346        });
347        let editor = cx.new(|cx| {
348            let mut editor = Editor::new(
349                editor::EditorMode::AutoHeight { max_lines: 8 },
350                buffer,
351                None,
352                false,
353                window,
354                cx,
355            );
356            editor.focus_handle(cx).focus(window);
357            editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
358            editor
359        });
360        self.editing_message = Some((
361            message_id,
362            EditMessageState {
363                editor: editor.clone(),
364            },
365        ));
366        cx.notify();
367    }
368
369    fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
370        self.editing_message.take();
371        cx.notify();
372    }
373
374    fn confirm_editing_message(
375        &mut self,
376        _: &menu::Confirm,
377        _: &mut Window,
378        cx: &mut Context<Self>,
379    ) {
380        let Some((message_id, state)) = self.editing_message.take() else {
381            return;
382        };
383        let edited_text = state.editor.read(cx).text(cx);
384        self.thread.update(cx, |thread, cx| {
385            thread.edit_message(message_id, Role::User, edited_text, cx);
386            for message_id in self.messages_after(message_id) {
387                thread.delete_message(*message_id, cx);
388            }
389        });
390
391        let provider = LanguageModelRegistry::read_global(cx).active_provider();
392        if provider
393            .as_ref()
394            .map_or(false, |provider| provider.must_accept_terms(cx))
395        {
396            cx.notify();
397            return;
398        }
399        let model_registry = LanguageModelRegistry::read_global(cx);
400        let Some(model) = model_registry.active_model() else {
401            return;
402        };
403
404        self.thread.update(cx, |thread, cx| {
405            thread.send_to_model(model, RequestKind::Chat, false, cx)
406        });
407        cx.notify();
408    }
409
410    fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
411        self.messages
412            .iter()
413            .rev()
414            .find(|message_id| {
415                self.thread
416                    .read(cx)
417                    .message(**message_id)
418                    .map_or(false, |message| message.role == Role::User)
419            })
420            .cloned()
421    }
422
423    fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
424        self.messages
425            .iter()
426            .position(|id| *id == message_id)
427            .map(|index| &self.messages[index + 1..])
428            .unwrap_or(&[])
429    }
430
431    fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
432        self.cancel_editing_message(&menu::Cancel, window, cx);
433    }
434
435    fn handle_regenerate_click(
436        &mut self,
437        _: &ClickEvent,
438        window: &mut Window,
439        cx: &mut Context<Self>,
440    ) {
441        self.confirm_editing_message(&menu::Confirm, window, cx);
442    }
443
444    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
445        let message_id = self.messages[ix];
446        let Some(message) = self.thread.read(cx).message(message_id) else {
447            return Empty.into_any();
448        };
449
450        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
451            return Empty.into_any();
452        };
453
454        let context = self.thread.read(cx).context_for_message(message_id);
455        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
456        let colors = cx.theme().colors();
457
458        // Don't render user messages that are just there for returning tool results.
459        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
460            return Empty.into_any();
461        }
462
463        let allow_editing_message =
464            message.role == Role::User && self.last_user_message(cx) == Some(message_id);
465
466        let edit_message_editor = self
467            .editing_message
468            .as_ref()
469            .filter(|(id, _)| *id == message_id)
470            .map(|(_, state)| state.editor.clone());
471
472        let message_content = v_flex()
473            .child(
474                if let Some(edit_message_editor) = edit_message_editor.clone() {
475                    div()
476                        .key_context("EditMessageEditor")
477                        .on_action(cx.listener(Self::cancel_editing_message))
478                        .on_action(cx.listener(Self::confirm_editing_message))
479                        .p_2p5()
480                        .child(edit_message_editor)
481                } else {
482                    div().p_2p5().text_ui(cx).child(markdown.clone())
483                },
484            )
485            .when_some(context, |parent, context| {
486                if !context.is_empty() {
487                    parent.child(
488                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
489                            context
490                                .into_iter()
491                                .map(|context| ContextPill::added(context, false, false, None)),
492                        ),
493                    )
494                } else {
495                    parent
496                }
497            });
498
499        let styled_message = match message.role {
500            Role::User => v_flex()
501                .id(("message-container", ix))
502                .pt_2p5()
503                .px_2p5()
504                .child(
505                    v_flex()
506                        .bg(colors.editor_background)
507                        .rounded_lg()
508                        .border_1()
509                        .border_color(colors.border)
510                        .shadow_sm()
511                        .child(
512                            h_flex()
513                                .py_1()
514                                .pl_2()
515                                .pr_1()
516                                .bg(colors.editor_foreground.opacity(0.05))
517                                .border_b_1()
518                                .border_color(colors.border)
519                                .justify_between()
520                                .rounded_t(px(6.))
521                                .child(
522                                    h_flex()
523                                        .gap_1p5()
524                                        .child(
525                                            Icon::new(IconName::PersonCircle)
526                                                .size(IconSize::XSmall)
527                                                .color(Color::Muted),
528                                        )
529                                        .child(
530                                            Label::new("You")
531                                                .size(LabelSize::Small)
532                                                .color(Color::Muted),
533                                        ),
534                                )
535                                .when_some(
536                                    edit_message_editor.clone(),
537                                    |this, edit_message_editor| {
538                                        let focus_handle = edit_message_editor.focus_handle(cx);
539                                        this.child(
540                                            h_flex()
541                                                .gap_1()
542                                                .child(
543                                                    Button::new("cancel-edit-message", "Cancel")
544                                                        .label_size(LabelSize::Small)
545                                                        .key_binding(
546                                                            KeyBinding::for_action_in(
547                                                                &menu::Cancel,
548                                                                &focus_handle,
549                                                                window,
550                                                                cx,
551                                                            )
552                                                            .map(|kb| kb.size(rems_from_px(12.))),
553                                                        )
554                                                        .on_click(
555                                                            cx.listener(Self::handle_cancel_click),
556                                                        ),
557                                                )
558                                                .child(
559                                                    Button::new(
560                                                        "confirm-edit-message",
561                                                        "Regenerate",
562                                                    )
563                                                    .label_size(LabelSize::Small)
564                                                    .key_binding(
565                                                        KeyBinding::for_action_in(
566                                                            &menu::Confirm,
567                                                            &focus_handle,
568                                                            window,
569                                                            cx,
570                                                        )
571                                                        .map(|kb| kb.size(rems_from_px(12.))),
572                                                    )
573                                                    .on_click(
574                                                        cx.listener(Self::handle_regenerate_click),
575                                                    ),
576                                                ),
577                                        )
578                                    },
579                                )
580                                .when(
581                                    edit_message_editor.is_none() && allow_editing_message,
582                                    |this| {
583                                        this.child(
584                                            Button::new("edit-message", "Edit")
585                                                .label_size(LabelSize::Small)
586                                                .on_click(cx.listener({
587                                                    let message_text = message.text.clone();
588                                                    move |this, _, window, cx| {
589                                                        this.start_editing_message(
590                                                            message_id,
591                                                            message_text.clone(),
592                                                            window,
593                                                            cx,
594                                                        );
595                                                    }
596                                                })),
597                                        )
598                                    },
599                                ),
600                        )
601                        .child(message_content),
602                ),
603            Role::Assistant => div()
604                .id(("message-container", ix))
605                .child(message_content)
606                .map(|parent| {
607                    if tool_uses.is_empty() {
608                        return parent;
609                    }
610
611                    parent.child(
612                        v_flex().children(
613                            tool_uses
614                                .into_iter()
615                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
616                        ),
617                    )
618                }),
619            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
620                v_flex()
621                    .bg(colors.editor_background)
622                    .rounded_sm()
623                    .child(message_content),
624            ),
625        };
626
627        styled_message.into_any()
628    }
629
630    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
631        let is_open = self
632            .expanded_tool_uses
633            .get(&tool_use.id)
634            .copied()
635            .unwrap_or_default();
636
637        div().px_2p5().child(
638            v_flex()
639                .gap_1()
640                .rounded_lg()
641                .border_1()
642                .border_color(cx.theme().colors().border)
643                .child(
644                    h_flex()
645                        .justify_between()
646                        .py_0p5()
647                        .pl_1()
648                        .pr_2()
649                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
650                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
651                        .when(!is_open, |element| element.rounded_md())
652                        .border_color(cx.theme().colors().border)
653                        .child(
654                            h_flex()
655                                .gap_1()
656                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
657                                    cx.listener({
658                                        let tool_use_id = tool_use.id.clone();
659                                        move |this, _event, _window, _cx| {
660                                            let is_open = this
661                                                .expanded_tool_uses
662                                                .entry(tool_use_id.clone())
663                                                .or_insert(false);
664
665                                            *is_open = !*is_open;
666                                        }
667                                    }),
668                                ))
669                                .child(Label::new(tool_use.name)),
670                        )
671                        .child(
672                            Label::new(match tool_use.status {
673                                ToolUseStatus::Pending => "Pending",
674                                ToolUseStatus::Running => "Running",
675                                ToolUseStatus::Finished(_) => "Finished",
676                                ToolUseStatus::Error(_) => "Error",
677                            })
678                            .size(LabelSize::XSmall)
679                            .buffer_font(cx),
680                        ),
681                )
682                .map(|parent| {
683                    if !is_open {
684                        return parent;
685                    }
686
687                    parent.child(
688                        v_flex()
689                            .child(
690                                v_flex()
691                                    .gap_0p5()
692                                    .py_1()
693                                    .px_2p5()
694                                    .border_b_1()
695                                    .border_color(cx.theme().colors().border)
696                                    .child(Label::new("Input:"))
697                                    .child(Label::new(
698                                        serde_json::to_string_pretty(&tool_use.input)
699                                            .unwrap_or_default(),
700                                    )),
701                            )
702                            .map(|parent| match tool_use.status {
703                                ToolUseStatus::Finished(output) => parent.child(
704                                    v_flex()
705                                        .gap_0p5()
706                                        .py_1()
707                                        .px_2p5()
708                                        .child(Label::new("Result:"))
709                                        .child(Label::new(output)),
710                                ),
711                                ToolUseStatus::Error(err) => parent.child(
712                                    v_flex()
713                                        .gap_0p5()
714                                        .py_1()
715                                        .px_2p5()
716                                        .child(Label::new("Error:"))
717                                        .child(Label::new(err)),
718                                ),
719                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
720                            }),
721                    )
722                }),
723        )
724    }
725}
726
727impl Render for ActiveThread {
728    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
729        v_flex()
730            .size_full()
731            .child(list(self.list_state.clone()).flex_grow())
732    }
733}