active_thread.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use editor::{Editor, MultiBuffer};
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
  7    Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
  8    Task, TextStyleRefinement, UnderlineStyle,
  9};
 10use language::{Buffer, LanguageRegistry};
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure, KeyBinding};
 16use util::ResultExt as _;
 17
 18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::tool_use::{ToolUse, ToolUseStatus};
 21use crate::ui::ContextPill;
 22
/// View state for displaying a single [`Thread`] as a list of messages.
///
/// Mirrors the thread's message ids into a `ListState`, caches a rendered `Markdown`
/// entity per message, and tracks UI-only state such as the message being edited and
/// which tool uses are expanded.
pub struct ActiveThread {
    /// Passed to each `Markdown` entity created for message text.
    language_registry: Arc<LanguageRegistry>,
    /// Store used to persist `thread` after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// The most recently spawned save task; assigning a new one drops the previous one.
    save_thread_task: Option<Task<()>>,
    /// Ids of the displayed messages, in list order; kept in sync with `list_state`.
    messages: Vec<MessageId>,
    /// State backing the message list; item `ix` renders `messages[ix]`.
    list_state: ListState,
    /// Rendered markdown for each displayed message.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The user message currently being edited, if any, together with its editor.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Disclosure state per tool use; a missing entry is treated as collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// The most recent error reported by the thread via `ThreadEvent::ShowError`.
    last_error: Option<ThreadError>,
    /// Subscriptions to `thread`, held for the lifetime of this view
    /// (presumably dropping them would end the subscription — gpui semantics).
    _subscriptions: Vec<Subscription>,
}
 36
/// UI state for an in-progress edit of a user message.
struct EditMessageState {
    /// Editor holding the message text being edited.
    editor: Entity<Editor>,
}
 40
 41impl ActiveThread {
 42    pub fn new(
 43        thread: Entity<Thread>,
 44        thread_store: Entity<ThreadStore>,
 45        language_registry: Arc<LanguageRegistry>,
 46        window: &mut Window,
 47        cx: &mut Context<Self>,
 48    ) -> Self {
 49        let subscriptions = vec![
 50            cx.observe(&thread, |_, _, cx| cx.notify()),
 51            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 52        ];
 53
 54        let mut this = Self {
 55            language_registry,
 56            thread_store,
 57            thread: thread.clone(),
 58            save_thread_task: None,
 59            messages: Vec::new(),
 60            rendered_messages_by_id: HashMap::default(),
 61            expanded_tool_uses: HashMap::default(),
 62            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 63                let this = cx.entity().downgrade();
 64                move |ix, window: &mut Window, cx: &mut App| {
 65                    this.update(cx, |this, cx| this.render_message(ix, window, cx))
 66                        .unwrap()
 67                }
 68            }),
 69            editing_message: None,
 70            last_error: None,
 71            _subscriptions: subscriptions,
 72        };
 73
 74        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 75            this.push_message(&message.id, message.text.clone(), window, cx);
 76        }
 77
 78        this
 79    }
 80
 81    pub fn thread(&self) -> &Entity<Thread> {
 82        &self.thread
 83    }
 84
 85    pub fn is_empty(&self) -> bool {
 86        self.messages.is_empty()
 87    }
 88
 89    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 90        self.thread.read(cx).summary()
 91    }
 92
 93    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 94        self.thread.read(cx).summary_or_default()
 95    }
 96
 97    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 98        self.last_error.take();
 99        self.thread
100            .update(cx, |thread, _cx| thread.cancel_last_completion())
101    }
102
103    pub fn last_error(&self) -> Option<ThreadError> {
104        self.last_error.clone()
105    }
106
107    pub fn clear_last_error(&mut self) {
108        self.last_error.take();
109    }
110
111    fn push_message(
112        &mut self,
113        id: &MessageId,
114        text: String,
115        window: &mut Window,
116        cx: &mut Context<Self>,
117    ) {
118        let old_len = self.messages.len();
119        self.messages.push(*id);
120        self.list_state.splice(old_len..old_len, 1);
121
122        let markdown = self.render_markdown(text.into(), window, cx);
123        self.rendered_messages_by_id.insert(*id, markdown);
124        self.list_state.scroll_to(ListOffset {
125            item_ix: old_len,
126            offset_in_item: Pixels(0.0),
127        });
128    }
129
130    fn edited_message(
131        &mut self,
132        id: &MessageId,
133        text: String,
134        window: &mut Window,
135        cx: &mut Context<Self>,
136    ) {
137        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
138            return;
139        };
140        self.list_state.splice(index..index + 1, 1);
141        let markdown = self.render_markdown(text.into(), window, cx);
142        self.rendered_messages_by_id.insert(*id, markdown);
143    }
144
145    fn deleted_message(&mut self, id: &MessageId) {
146        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
147            return;
148        };
149        self.messages.remove(index);
150        self.list_state.splice(index..index + 1, 0);
151        self.rendered_messages_by_id.remove(id);
152    }
153
154    fn render_markdown(
155        &self,
156        text: SharedString,
157        window: &Window,
158        cx: &mut Context<Self>,
159    ) -> Entity<Markdown> {
160        let theme_settings = ThemeSettings::get_global(cx);
161        let colors = cx.theme().colors();
162        let ui_font_size = TextSize::Default.rems(cx);
163        let buffer_font_size = TextSize::Small.rems(cx);
164        let mut text_style = window.text_style();
165
166        text_style.refine(&TextStyleRefinement {
167            font_family: Some(theme_settings.ui_font.family.clone()),
168            font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
169            font_features: Some(theme_settings.ui_font.features.clone()),
170            font_size: Some(ui_font_size.into()),
171            color: Some(cx.theme().colors().text),
172            ..Default::default()
173        });
174
175        let markdown_style = MarkdownStyle {
176            base_text_style: text_style,
177            syntax: cx.theme().syntax().clone(),
178            selection_background_color: cx.theme().players().local().selection,
179            code_block_overflow_x_scroll: true,
180            table_overflow_x_scroll: true,
181            code_block: StyleRefinement {
182                margin: EdgesRefinement {
183                    top: Some(Length::Definite(rems(0.).into())),
184                    left: Some(Length::Definite(rems(0.).into())),
185                    right: Some(Length::Definite(rems(0.).into())),
186                    bottom: Some(Length::Definite(rems(0.5).into())),
187                },
188                padding: EdgesRefinement {
189                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
190                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
191                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
192                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193                },
194                background: Some(colors.editor_background.into()),
195                border_color: Some(colors.border_variant),
196                border_widths: EdgesRefinement {
197                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
198                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
199                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
200                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
201                },
202                text: Some(TextStyleRefinement {
203                    font_family: Some(theme_settings.buffer_font.family.clone()),
204                    font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
205                    font_features: Some(theme_settings.buffer_font.features.clone()),
206                    font_size: Some(buffer_font_size.into()),
207                    ..Default::default()
208                }),
209                ..Default::default()
210            },
211            inline_code: TextStyleRefinement {
212                font_family: Some(theme_settings.buffer_font.family.clone()),
213                font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
214                font_features: Some(theme_settings.buffer_font.features.clone()),
215                font_size: Some(buffer_font_size.into()),
216                background_color: Some(colors.editor_foreground.opacity(0.1)),
217                ..Default::default()
218            },
219            link: TextStyleRefinement {
220                background_color: Some(colors.editor_foreground.opacity(0.025)),
221                underline: Some(UnderlineStyle {
222                    color: Some(colors.text_accent.opacity(0.5)),
223                    thickness: px(1.),
224                    ..Default::default()
225                }),
226                ..Default::default()
227            },
228            ..Default::default()
229        };
230
231        cx.new(|cx| {
232            Markdown::new(
233                text,
234                markdown_style,
235                Some(self.language_registry.clone()),
236                None,
237                cx,
238            )
239        })
240    }
241
242    fn handle_thread_event(
243        &mut self,
244        _thread: &Entity<Thread>,
245        event: &ThreadEvent,
246        window: &mut Window,
247        cx: &mut Context<Self>,
248    ) {
249        match event {
250            ThreadEvent::ShowError(error) => {
251                self.last_error = Some(error.clone());
252            }
253            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
254                self.save_thread(cx);
255            }
256            ThreadEvent::StreamedAssistantText(message_id, text) => {
257                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
258                    markdown.update(cx, |markdown, cx| {
259                        markdown.append(text, cx);
260                    });
261                }
262            }
263            ThreadEvent::MessageAdded(message_id) => {
264                if let Some(message_text) = self
265                    .thread
266                    .read(cx)
267                    .message(*message_id)
268                    .map(|message| message.text.clone())
269                {
270                    self.push_message(message_id, message_text, window, cx);
271                }
272
273                self.save_thread(cx);
274                cx.notify();
275            }
276            ThreadEvent::MessageEdited(message_id) => {
277                if let Some(message_text) = self
278                    .thread
279                    .read(cx)
280                    .message(*message_id)
281                    .map(|message| message.text.clone())
282                {
283                    self.edited_message(message_id, message_text, window, cx);
284                }
285
286                self.save_thread(cx);
287                cx.notify();
288            }
289            ThreadEvent::MessageDeleted(message_id) => {
290                self.deleted_message(message_id);
291                self.save_thread(cx);
292                cx.notify();
293            }
294            ThreadEvent::UsePendingTools => {
295                self.thread.update(cx, |thread, cx| {
296                    thread.use_pending_tools(cx);
297                });
298            }
299            ThreadEvent::ToolFinished { .. } => {
300                if self.thread.read(cx).all_tools_finished() {
301                    let model_registry = LanguageModelRegistry::read_global(cx);
302                    if let Some(model) = model_registry.active_model() {
303                        self.thread.update(cx, |thread, cx| {
304                            thread.send_tool_results_to_model(model, cx);
305                        });
306                    }
307                }
308            }
309            ThreadEvent::ScriptFinished => {
310                let model_registry = LanguageModelRegistry::read_global(cx);
311                if let Some(model) = model_registry.active_model() {
312                    self.thread.update(cx, |thread, cx| {
313                        thread.send_to_model(model, RequestKind::Chat, false, cx);
314                    });
315                }
316            }
317        }
318    }
319
320    /// Spawns a task to save the active thread.
321    ///
322    /// Only one task to save the thread will be in flight at a time.
323    fn save_thread(&mut self, cx: &mut Context<Self>) {
324        let thread = self.thread.clone();
325        self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
326            let task = this
327                .update(&mut cx, |this, cx| {
328                    this.thread_store
329                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
330                })
331                .ok();
332
333            if let Some(task) = task {
334                task.await.log_err();
335            }
336        }));
337    }
338
339    fn start_editing_message(
340        &mut self,
341        message_id: MessageId,
342        message_text: String,
343        window: &mut Window,
344        cx: &mut Context<Self>,
345    ) {
346        let buffer = cx.new(|cx| {
347            MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
348        });
349        let editor = cx.new(|cx| {
350            let mut editor = Editor::new(
351                editor::EditorMode::AutoHeight { max_lines: 8 },
352                buffer,
353                None,
354                false,
355                window,
356                cx,
357            );
358            editor.focus_handle(cx).focus(window);
359            editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
360            editor
361        });
362        self.editing_message = Some((
363            message_id,
364            EditMessageState {
365                editor: editor.clone(),
366            },
367        ));
368        cx.notify();
369    }
370
371    fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
372        self.editing_message.take();
373        cx.notify();
374    }
375
376    fn confirm_editing_message(
377        &mut self,
378        _: &menu::Confirm,
379        _: &mut Window,
380        cx: &mut Context<Self>,
381    ) {
382        let Some((message_id, state)) = self.editing_message.take() else {
383            return;
384        };
385        let edited_text = state.editor.read(cx).text(cx);
386        self.thread.update(cx, |thread, cx| {
387            thread.edit_message(message_id, Role::User, edited_text, cx);
388            for message_id in self.messages_after(message_id) {
389                thread.delete_message(*message_id, cx);
390            }
391        });
392
393        let provider = LanguageModelRegistry::read_global(cx).active_provider();
394        if provider
395            .as_ref()
396            .map_or(false, |provider| provider.must_accept_terms(cx))
397        {
398            cx.notify();
399            return;
400        }
401        let model_registry = LanguageModelRegistry::read_global(cx);
402        let Some(model) = model_registry.active_model() else {
403            return;
404        };
405
406        self.thread.update(cx, |thread, cx| {
407            thread.send_to_model(model, RequestKind::Chat, false, cx)
408        });
409        cx.notify();
410    }
411
412    fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
413        self.messages
414            .iter()
415            .rev()
416            .find(|message_id| {
417                self.thread
418                    .read(cx)
419                    .message(**message_id)
420                    .map_or(false, |message| message.role == Role::User)
421            })
422            .cloned()
423    }
424
425    fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
426        self.messages
427            .iter()
428            .position(|id| *id == message_id)
429            .map(|index| &self.messages[index + 1..])
430            .unwrap_or(&[])
431    }
432
433    fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
434        self.cancel_editing_message(&menu::Cancel, window, cx);
435    }
436
437    fn handle_regenerate_click(
438        &mut self,
439        _: &ClickEvent,
440        window: &mut Window,
441        cx: &mut Context<Self>,
442    ) {
443        self.confirm_editing_message(&menu::Confirm, window, cx);
444    }
445
446    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
447        let message_id = self.messages[ix];
448        let Some(message) = self.thread.read(cx).message(message_id) else {
449            return Empty.into_any();
450        };
451
452        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
453            return Empty.into_any();
454        };
455
456        let thread = self.thread.read(cx);
457
458        let context = thread.context_for_message(message_id);
459        let tool_uses = thread.tool_uses_for_message(message_id);
460
461        // Don't render user messages that are just there for returning tool results.
462        if message.role == Role::User && thread.message_has_tool_results(message_id) {
463            return Empty.into_any();
464        }
465
466        let allow_editing_message =
467            message.role == Role::User && self.last_user_message(cx) == Some(message_id);
468
469        let edit_message_editor = self
470            .editing_message
471            .as_ref()
472            .filter(|(id, _)| *id == message_id)
473            .map(|(_, state)| state.editor.clone());
474
475        let colors = cx.theme().colors();
476
477        let message_content = v_flex()
478            .child(
479                if let Some(edit_message_editor) = edit_message_editor.clone() {
480                    div()
481                        .key_context("EditMessageEditor")
482                        .on_action(cx.listener(Self::cancel_editing_message))
483                        .on_action(cx.listener(Self::confirm_editing_message))
484                        .p_2p5()
485                        .child(edit_message_editor)
486                } else {
487                    div().p_2p5().text_ui(cx).child(markdown.clone())
488                },
489            )
490            .when_some(context, |parent, context| {
491                if !context.is_empty() {
492                    parent.child(
493                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
494                            context
495                                .into_iter()
496                                .map(|context| ContextPill::added(context, false, false, None)),
497                        ),
498                    )
499                } else {
500                    parent
501                }
502            });
503
504        let styled_message = match message.role {
505            Role::User => v_flex()
506                .id(("message-container", ix))
507                .pt_2p5()
508                .px_2p5()
509                .child(
510                    v_flex()
511                        .bg(colors.editor_background)
512                        .rounded_lg()
513                        .border_1()
514                        .border_color(colors.border)
515                        .shadow_sm()
516                        .child(
517                            h_flex()
518                                .py_1()
519                                .pl_2()
520                                .pr_1()
521                                .bg(colors.editor_foreground.opacity(0.05))
522                                .border_b_1()
523                                .border_color(colors.border)
524                                .justify_between()
525                                .rounded_t(px(6.))
526                                .child(
527                                    h_flex()
528                                        .gap_1p5()
529                                        .child(
530                                            Icon::new(IconName::PersonCircle)
531                                                .size(IconSize::XSmall)
532                                                .color(Color::Muted),
533                                        )
534                                        .child(
535                                            Label::new("You")
536                                                .size(LabelSize::Small)
537                                                .color(Color::Muted),
538                                        ),
539                                )
540                                .when_some(
541                                    edit_message_editor.clone(),
542                                    |this, edit_message_editor| {
543                                        let focus_handle = edit_message_editor.focus_handle(cx);
544                                        this.child(
545                                            h_flex()
546                                                .gap_1()
547                                                .child(
548                                                    Button::new("cancel-edit-message", "Cancel")
549                                                        .label_size(LabelSize::Small)
550                                                        .key_binding(
551                                                            KeyBinding::for_action_in(
552                                                                &menu::Cancel,
553                                                                &focus_handle,
554                                                                window,
555                                                                cx,
556                                                            )
557                                                            .map(|kb| kb.size(rems_from_px(12.))),
558                                                        )
559                                                        .on_click(
560                                                            cx.listener(Self::handle_cancel_click),
561                                                        ),
562                                                )
563                                                .child(
564                                                    Button::new(
565                                                        "confirm-edit-message",
566                                                        "Regenerate",
567                                                    )
568                                                    .label_size(LabelSize::Small)
569                                                    .key_binding(
570                                                        KeyBinding::for_action_in(
571                                                            &menu::Confirm,
572                                                            &focus_handle,
573                                                            window,
574                                                            cx,
575                                                        )
576                                                        .map(|kb| kb.size(rems_from_px(12.))),
577                                                    )
578                                                    .on_click(
579                                                        cx.listener(Self::handle_regenerate_click),
580                                                    ),
581                                                ),
582                                        )
583                                    },
584                                )
585                                .when(
586                                    edit_message_editor.is_none() && allow_editing_message,
587                                    |this| {
588                                        this.child(
589                                            Button::new("edit-message", "Edit")
590                                                .label_size(LabelSize::Small)
591                                                .on_click(cx.listener({
592                                                    let message_text = message.text.clone();
593                                                    move |this, _, window, cx| {
594                                                        this.start_editing_message(
595                                                            message_id,
596                                                            message_text.clone(),
597                                                            window,
598                                                            cx,
599                                                        );
600                                                    }
601                                                })),
602                                        )
603                                    },
604                                ),
605                        )
606                        .child(message_content),
607                ),
608            Role::Assistant => div()
609                .id(("message-container", ix))
610                .child(message_content)
611                .map(|parent| {
612                    if tool_uses.is_empty() {
613                        return parent;
614                    }
615
616                    parent.child(
617                        v_flex().children(
618                            tool_uses
619                                .into_iter()
620                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
621                        ),
622                    )
623                }),
624            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
625                v_flex()
626                    .bg(colors.editor_background)
627                    .rounded_sm()
628                    .child(message_content),
629            ),
630        };
631
632        styled_message.into_any()
633    }
634
635    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
636        let is_open = self
637            .expanded_tool_uses
638            .get(&tool_use.id)
639            .copied()
640            .unwrap_or_default();
641
642        div().px_2p5().child(
643            v_flex()
644                .gap_1()
645                .rounded_lg()
646                .border_1()
647                .border_color(cx.theme().colors().border)
648                .child(
649                    h_flex()
650                        .justify_between()
651                        .py_0p5()
652                        .pl_1()
653                        .pr_2()
654                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
655                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
656                        .when(!is_open, |element| element.rounded_md())
657                        .border_color(cx.theme().colors().border)
658                        .child(
659                            h_flex()
660                                .gap_1()
661                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
662                                    cx.listener({
663                                        let tool_use_id = tool_use.id.clone();
664                                        move |this, _event, _window, _cx| {
665                                            let is_open = this
666                                                .expanded_tool_uses
667                                                .entry(tool_use_id.clone())
668                                                .or_insert(false);
669
670                                            *is_open = !*is_open;
671                                        }
672                                    }),
673                                ))
674                                .child(Label::new(tool_use.name)),
675                        )
676                        .child(
677                            Label::new(match tool_use.status {
678                                ToolUseStatus::Pending => "Pending",
679                                ToolUseStatus::Running => "Running",
680                                ToolUseStatus::Finished(_) => "Finished",
681                                ToolUseStatus::Error(_) => "Error",
682                            })
683                            .size(LabelSize::XSmall)
684                            .buffer_font(cx),
685                        ),
686                )
687                .map(|parent| {
688                    if !is_open {
689                        return parent;
690                    }
691
692                    parent.child(
693                        v_flex()
694                            .child(
695                                v_flex()
696                                    .gap_0p5()
697                                    .py_1()
698                                    .px_2p5()
699                                    .border_b_1()
700                                    .border_color(cx.theme().colors().border)
701                                    .child(Label::new("Input:"))
702                                    .child(Label::new(
703                                        serde_json::to_string_pretty(&tool_use.input)
704                                            .unwrap_or_default(),
705                                    )),
706                            )
707                            .map(|parent| match tool_use.status {
708                                ToolUseStatus::Finished(output) => parent.child(
709                                    v_flex()
710                                        .gap_0p5()
711                                        .py_1()
712                                        .px_2p5()
713                                        .child(Label::new("Result:"))
714                                        .child(Label::new(output)),
715                                ),
716                                ToolUseStatus::Error(err) => parent.child(
717                                    v_flex()
718                                        .gap_0p5()
719                                        .py_1()
720                                        .px_2p5()
721                                        .child(Label::new("Error:"))
722                                        .child(Label::new(err)),
723                                ),
724                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
725                            }),
726                    )
727                }),
728        )
729    }
730}
731
732impl Render for ActiveThread {
733    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
734        v_flex()
735            .size_full()
736            .child(list(self.list_state.clone()).flex_grow())
737    }
738}