active_thread.rs

  1use std::sync::Arc;
  2
  3use collections::HashMap;
  4use editor::{Editor, MultiBuffer};
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
  7    Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
  8    Task, TextStyleRefinement, UnderlineStyle,
  9};
 10use language::{Buffer, LanguageRegistry};
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure, KeyBinding};
 16use util::ResultExt as _;
 17
 18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::tool_use::{ToolUse, ToolUseStatus};
 21use crate::ui::ContextPill;
 22
/// View state for the thread currently being displayed: the thread entity
/// itself plus the per-message rendering and editing state built around it.
pub struct ActiveThread {
    /// Handed to every `Markdown` entity created by `render_markdown`.
    language_registry: Arc<LanguageRegistry>,
    /// Store that `save_thread` asks to save the thread.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed; observed and subscribed to in `new`.
    thread: Entity<Thread>,
    /// Most recently spawned save task; overwritten on every `save_thread` call.
    save_thread_task: Option<Task<()>>,
    /// Message ids in display order; `render_message` indexes into this by list item index.
    messages: Vec<MessageId>,
    /// List state whose render callback forwards to `render_message`.
    list_state: ListState,
    /// Markdown entity holding the rendered text of each message.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited, together with its editor, if any.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Open/closed state of tool-use disclosures; a missing entry is treated as closed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Last error received through `ThreadEvent::ShowError`, until taken or cleared.
    last_error: Option<ThreadError>,
    /// Handles returned by the observe/subscribe calls in `new`. Never read;
    /// presumably held so the subscriptions live as long as this value.
    _subscriptions: Vec<Subscription>,
}
 36
/// State kept while a message is being edited in place.
struct EditMessageState {
    /// Editor created by `start_editing_message`; its text is read back when
    /// the edit is confirmed.
    editor: Entity<Editor>,
}
 40
 41impl ActiveThread {
 42    pub fn new(
 43        thread: Entity<Thread>,
 44        thread_store: Entity<ThreadStore>,
 45        language_registry: Arc<LanguageRegistry>,
 46        window: &mut Window,
 47        cx: &mut Context<Self>,
 48    ) -> Self {
 49        let subscriptions = vec![
 50            cx.observe(&thread, |_, _, cx| cx.notify()),
 51            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 52        ];
 53
 54        let mut this = Self {
 55            language_registry,
 56            thread_store,
 57            thread: thread.clone(),
 58            save_thread_task: None,
 59            messages: Vec::new(),
 60            rendered_messages_by_id: HashMap::default(),
 61            expanded_tool_uses: HashMap::default(),
 62            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 63                let this = cx.entity().downgrade();
 64                move |ix, window: &mut Window, cx: &mut App| {
 65                    this.update(cx, |this, cx| this.render_message(ix, window, cx))
 66                        .unwrap()
 67                }
 68            }),
 69            editing_message: None,
 70            last_error: None,
 71            _subscriptions: subscriptions,
 72        };
 73
 74        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 75            this.push_message(&message.id, message.text.clone(), window, cx);
 76        }
 77
 78        this
 79    }
 80
    /// Returns the thread entity this view displays.
    pub fn thread(&self) -> &Entity<Thread> {
        &self.thread
    }
 84
    /// Returns `true` when no messages are tracked by this view.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 88
 89    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 90        self.thread.read(cx).summary()
 91    }
 92
 93    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 94        self.thread.read(cx).summary_or_default()
 95    }
 96
 97    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 98        self.last_error.take();
 99        self.thread
100            .update(cx, |thread, _cx| thread.cancel_last_completion())
101    }
102
    /// Returns a clone of the stored error, if there is one.
    pub fn last_error(&self) -> Option<ThreadError> {
        self.last_error.clone()
    }
106
107    pub fn clear_last_error(&mut self) {
108        self.last_error.take();
109    }
110
111    fn push_message(
112        &mut self,
113        id: &MessageId,
114        text: String,
115        window: &mut Window,
116        cx: &mut Context<Self>,
117    ) {
118        let old_len = self.messages.len();
119        self.messages.push(*id);
120        self.list_state.splice(old_len..old_len, 1);
121
122        let markdown = self.render_markdown(text.into(), window, cx);
123        self.rendered_messages_by_id.insert(*id, markdown);
124        self.list_state.scroll_to(ListOffset {
125            item_ix: old_len,
126            offset_in_item: Pixels(0.0),
127        });
128    }
129
130    fn edited_message(
131        &mut self,
132        id: &MessageId,
133        text: String,
134        window: &mut Window,
135        cx: &mut Context<Self>,
136    ) {
137        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
138            return;
139        };
140        self.list_state.splice(index..index + 1, 1);
141        let markdown = self.render_markdown(text.into(), window, cx);
142        self.rendered_messages_by_id.insert(*id, markdown);
143    }
144
145    fn deleted_message(&mut self, id: &MessageId) {
146        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
147            return;
148        };
149        self.messages.remove(index);
150        self.list_state.splice(index..index + 1, 0);
151        self.rendered_messages_by_id.remove(id);
152    }
153
154    fn render_markdown(
155        &self,
156        text: SharedString,
157        window: &Window,
158        cx: &mut Context<Self>,
159    ) -> Entity<Markdown> {
160        let theme_settings = ThemeSettings::get_global(cx);
161        let colors = cx.theme().colors();
162        let ui_font_size = TextSize::Default.rems(cx);
163        let buffer_font_size = TextSize::Small.rems(cx);
164        let mut text_style = window.text_style();
165
166        text_style.refine(&TextStyleRefinement {
167            font_family: Some(theme_settings.ui_font.family.clone()),
168            font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
169            font_features: Some(theme_settings.ui_font.features.clone()),
170            font_size: Some(ui_font_size.into()),
171            color: Some(cx.theme().colors().text),
172            ..Default::default()
173        });
174
175        let markdown_style = MarkdownStyle {
176            base_text_style: text_style,
177            syntax: cx.theme().syntax().clone(),
178            selection_background_color: cx.theme().players().local().selection,
179            code_block_overflow_x_scroll: true,
180            table_overflow_x_scroll: true,
181            code_block: StyleRefinement {
182                margin: EdgesRefinement {
183                    top: Some(Length::Definite(rems(0.).into())),
184                    left: Some(Length::Definite(rems(0.).into())),
185                    right: Some(Length::Definite(rems(0.).into())),
186                    bottom: Some(Length::Definite(rems(0.5).into())),
187                },
188                padding: EdgesRefinement {
189                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
190                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
191                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
192                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193                },
194                background: Some(colors.editor_background.into()),
195                border_color: Some(colors.border_variant),
196                border_widths: EdgesRefinement {
197                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
198                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
199                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
200                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
201                },
202                text: Some(TextStyleRefinement {
203                    font_family: Some(theme_settings.buffer_font.family.clone()),
204                    font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
205                    font_features: Some(theme_settings.buffer_font.features.clone()),
206                    font_size: Some(buffer_font_size.into()),
207                    ..Default::default()
208                }),
209                ..Default::default()
210            },
211            inline_code: TextStyleRefinement {
212                font_family: Some(theme_settings.buffer_font.family.clone()),
213                font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
214                font_features: Some(theme_settings.buffer_font.features.clone()),
215                font_size: Some(buffer_font_size.into()),
216                background_color: Some(colors.editor_foreground.opacity(0.1)),
217                ..Default::default()
218            },
219            link: TextStyleRefinement {
220                background_color: Some(colors.editor_foreground.opacity(0.025)),
221                underline: Some(UnderlineStyle {
222                    color: Some(colors.text_accent.opacity(0.5)),
223                    thickness: px(1.),
224                    ..Default::default()
225                }),
226                ..Default::default()
227            },
228            ..Default::default()
229        };
230
231        cx.new(|cx| {
232            Markdown::new(
233                text,
234                markdown_style,
235                Some(self.language_registry.clone()),
236                None,
237                cx,
238            )
239        })
240    }
241
242    fn handle_thread_event(
243        &mut self,
244        _: &Entity<Thread>,
245        event: &ThreadEvent,
246        window: &mut Window,
247        cx: &mut Context<Self>,
248    ) {
249        match event {
250            ThreadEvent::ShowError(error) => {
251                self.last_error = Some(error.clone());
252            }
253            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
254                self.save_thread(cx);
255            }
256            ThreadEvent::StreamedAssistantText(message_id, text) => {
257                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
258                    markdown.update(cx, |markdown, cx| {
259                        markdown.append(text, cx);
260                    });
261                }
262            }
263            ThreadEvent::MessageAdded(message_id) => {
264                if let Some(message_text) = self
265                    .thread
266                    .read(cx)
267                    .message(*message_id)
268                    .map(|message| message.text.clone())
269                {
270                    self.push_message(message_id, message_text, window, cx);
271                }
272
273                self.save_thread(cx);
274                cx.notify();
275            }
276            ThreadEvent::MessageEdited(message_id) => {
277                if let Some(message_text) = self
278                    .thread
279                    .read(cx)
280                    .message(*message_id)
281                    .map(|message| message.text.clone())
282                {
283                    self.edited_message(message_id, message_text, window, cx);
284                }
285
286                self.save_thread(cx);
287                cx.notify();
288            }
289            ThreadEvent::MessageDeleted(message_id) => {
290                self.deleted_message(message_id);
291                self.save_thread(cx);
292                cx.notify();
293            }
294            ThreadEvent::UsePendingTools => {
295                self.thread.update(cx, |thread, cx| {
296                    thread.use_pending_tools(cx);
297                });
298            }
299            ThreadEvent::ToolFinished { .. } => {
300                if self.thread.read(cx).all_tools_finished() {
301                    let model_registry = LanguageModelRegistry::read_global(cx);
302                    if let Some(model) = model_registry.active_model() {
303                        self.thread.update(cx, |thread, cx| {
304                            thread.send_tool_results_to_model(model, cx);
305                        });
306                    }
307                }
308            }
309        }
310    }
311
312    /// Spawns a task to save the active thread.
313    ///
314    /// Only one task to save the thread will be in flight at a time.
315    fn save_thread(&mut self, cx: &mut Context<Self>) {
316        let thread = self.thread.clone();
317        self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
318            let task = this
319                .update(&mut cx, |this, cx| {
320                    this.thread_store
321                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
322                })
323                .ok();
324
325            if let Some(task) = task {
326                task.await.log_err();
327            }
328        }));
329    }
330
331    fn start_editing_message(
332        &mut self,
333        message_id: MessageId,
334        message_text: String,
335        window: &mut Window,
336        cx: &mut Context<Self>,
337    ) {
338        let buffer = cx.new(|cx| {
339            MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
340        });
341        let editor = cx.new(|cx| {
342            let mut editor = Editor::new(
343                editor::EditorMode::AutoHeight { max_lines: 8 },
344                buffer,
345                None,
346                false,
347                window,
348                cx,
349            );
350            editor.focus_handle(cx).focus(window);
351            editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
352            editor
353        });
354        self.editing_message = Some((
355            message_id,
356            EditMessageState {
357                editor: editor.clone(),
358            },
359        ));
360        cx.notify();
361    }
362
363    fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
364        self.editing_message.take();
365        cx.notify();
366    }
367
368    fn confirm_editing_message(
369        &mut self,
370        _: &menu::Confirm,
371        _: &mut Window,
372        cx: &mut Context<Self>,
373    ) {
374        let Some((message_id, state)) = self.editing_message.take() else {
375            return;
376        };
377        let edited_text = state.editor.read(cx).text(cx);
378        self.thread.update(cx, |thread, cx| {
379            thread.edit_message(message_id, Role::User, edited_text, cx);
380            for message_id in self.messages_after(message_id) {
381                thread.delete_message(*message_id, cx);
382            }
383        });
384
385        let provider = LanguageModelRegistry::read_global(cx).active_provider();
386        if provider
387            .as_ref()
388            .map_or(false, |provider| provider.must_accept_terms(cx))
389        {
390            cx.notify();
391            return;
392        }
393        let model_registry = LanguageModelRegistry::read_global(cx);
394        let Some(model) = model_registry.active_model() else {
395            return;
396        };
397
398        self.thread.update(cx, |thread, cx| {
399            thread.send_to_model(model, RequestKind::Chat, false, cx)
400        });
401        cx.notify();
402    }
403
404    fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
405        self.messages
406            .iter()
407            .rev()
408            .find(|message_id| {
409                self.thread
410                    .read(cx)
411                    .message(**message_id)
412                    .map_or(false, |message| message.role == Role::User)
413            })
414            .cloned()
415    }
416
417    fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
418        self.messages
419            .iter()
420            .position(|id| *id == message_id)
421            .map(|index| &self.messages[index + 1..])
422            .unwrap_or(&[])
423    }
424
425    fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
426        self.cancel_editing_message(&menu::Cancel, window, cx);
427    }
428
429    fn handle_regenerate_click(
430        &mut self,
431        _: &ClickEvent,
432        window: &mut Window,
433        cx: &mut Context<Self>,
434    ) {
435        self.confirm_editing_message(&menu::Confirm, window, cx);
436    }
437
    /// Builds the element for the list item at index `ix`.
    ///
    /// Returns `Empty` when the thread has no message for the id, when no
    /// markdown was rendered for it, or when it is a user message that only
    /// carries tool results. Otherwise the message is styled by role: user
    /// messages get a bordered card with a header (and edit controls),
    /// assistant messages are followed by their tool uses, and system
    /// messages get a plain background.
    ///
    /// # Panics
    ///
    /// Indexes `self.messages` directly, so it panics if `ix` is out of range.
    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
        let message_id = self.messages[ix];
        let Some(message) = self.thread.read(cx).message(message_id) else {
            return Empty.into_any();
        };

        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
            return Empty.into_any();
        };

        let context = self.thread.read(cx).context_for_message(message_id);
        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
        let colors = cx.theme().colors();

        // Don't render user messages that are just there for returning tool results.
        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
            return Empty.into_any();
        }

        // Only the most recent user message offers an "Edit" button.
        let allow_editing_message =
            message.role == Role::User && self.last_user_message(cx) == Some(message_id);

        // The editor is present only when this is the message being edited.
        let edit_message_editor = self
            .editing_message
            .as_ref()
            .filter(|(id, _)| *id == message_id)
            .map(|(_, state)| state.editor.clone());

        // Body of the message: either the in-place editor or the rendered
        // markdown, followed by context pills when the message has context.
        let message_content = v_flex()
            .child(
                if let Some(edit_message_editor) = edit_message_editor.clone() {
                    div()
                        .key_context("EditMessageEditor")
                        .on_action(cx.listener(Self::cancel_editing_message))
                        .on_action(cx.listener(Self::confirm_editing_message))
                        .p_2p5()
                        .child(edit_message_editor)
                } else {
                    div().p_2p5().text_ui(cx).child(markdown.clone())
                },
            )
            .when_some(context, |parent, context| {
                if !context.is_empty() {
                    parent.child(
                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
                            context
                                .into_iter()
                                .map(|context| ContextPill::added(context, false, false, None)),
                        ),
                    )
                } else {
                    parent
                }
            });

        let styled_message = match message.role {
            Role::User => v_flex()
                .id(("message-container", ix))
                .pt_2p5()
                .px_2p5()
                .child(
                    v_flex()
                        .bg(colors.editor_background)
                        .rounded_lg()
                        .border_1()
                        .border_color(colors.border)
                        .shadow_sm()
                        .child(
                            // Header row: author label on the left, edit
                            // controls on the right.
                            h_flex()
                                .py_1()
                                .pl_2()
                                .pr_1()
                                .bg(colors.editor_foreground.opacity(0.05))
                                .border_b_1()
                                .border_color(colors.border)
                                .justify_between()
                                .rounded_t(px(6.))
                                .child(
                                    h_flex()
                                        .gap_1p5()
                                        .child(
                                            Icon::new(IconName::PersonCircle)
                                                .size(IconSize::XSmall)
                                                .color(Color::Muted),
                                        )
                                        .child(
                                            Label::new("You")
                                                .size(LabelSize::Small)
                                                .color(Color::Muted),
                                        ),
                                )
                                // While editing: "Cancel" and "Regenerate"
                                // buttons with their key bindings.
                                .when_some(
                                    edit_message_editor.clone(),
                                    |this, edit_message_editor| {
                                        let focus_handle = edit_message_editor.focus_handle(cx);
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Button::new("cancel-edit-message", "Cancel")
                                                        .label_size(LabelSize::Small)
                                                        .key_binding(
                                                            KeyBinding::for_action_in(
                                                                &menu::Cancel,
                                                                &focus_handle,
                                                                window,
                                                                cx,
                                                            )
                                                            .map(|kb| kb.size(rems_from_px(12.))),
                                                        )
                                                        .on_click(
                                                            cx.listener(Self::handle_cancel_click),
                                                        ),
                                                )
                                                .child(
                                                    Button::new(
                                                        "confirm-edit-message",
                                                        "Regenerate",
                                                    )
                                                    .label_size(LabelSize::Small)
                                                    .key_binding(
                                                        KeyBinding::for_action_in(
                                                            &menu::Confirm,
                                                            &focus_handle,
                                                            window,
                                                            cx,
                                                        )
                                                        .map(|kb| kb.size(rems_from_px(12.))),
                                                    )
                                                    .on_click(
                                                        cx.listener(Self::handle_regenerate_click),
                                                    ),
                                                ),
                                        )
                                    },
                                )
                                // Not editing: offer "Edit" when this message
                                // is allowed to be edited.
                                .when(
                                    edit_message_editor.is_none() && allow_editing_message,
                                    |this| {
                                        this.child(
                                            Button::new("edit-message", "Edit")
                                                .label_size(LabelSize::Small)
                                                .on_click(cx.listener({
                                                    let message_text = message.text.clone();
                                                    move |this, _, window, cx| {
                                                        this.start_editing_message(
                                                            message_id,
                                                            message_text.clone(),
                                                            window,
                                                            cx,
                                                        );
                                                    }
                                                })),
                                        )
                                    },
                                ),
                        )
                        .child(message_content),
                ),
            Role::Assistant => div()
                .id(("message-container", ix))
                .child(message_content)
                .map(|parent| {
                    if tool_uses.is_empty() {
                        return parent;
                    }

                    // Tool uses are listed below the assistant's text.
                    parent.child(
                        v_flex().children(
                            tool_uses
                                .into_iter()
                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
                        ),
                    )
                }),
            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
                v_flex()
                    .bg(colors.editor_background)
                    .rounded_sm()
                    .child(message_content),
            ),
        };

        styled_message.into_any()
    }
623
624    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
625        let is_open = self
626            .expanded_tool_uses
627            .get(&tool_use.id)
628            .copied()
629            .unwrap_or_default();
630
631        div().px_2p5().child(
632            v_flex()
633                .gap_1()
634                .rounded_lg()
635                .border_1()
636                .border_color(cx.theme().colors().border)
637                .child(
638                    h_flex()
639                        .justify_between()
640                        .py_0p5()
641                        .pl_1()
642                        .pr_2()
643                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
644                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
645                        .when(!is_open, |element| element.rounded_md())
646                        .border_color(cx.theme().colors().border)
647                        .child(
648                            h_flex()
649                                .gap_1()
650                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
651                                    cx.listener({
652                                        let tool_use_id = tool_use.id.clone();
653                                        move |this, _event, _window, _cx| {
654                                            let is_open = this
655                                                .expanded_tool_uses
656                                                .entry(tool_use_id.clone())
657                                                .or_insert(false);
658
659                                            *is_open = !*is_open;
660                                        }
661                                    }),
662                                ))
663                                .child(Label::new(tool_use.name)),
664                        )
665                        .child(
666                            Label::new(match tool_use.status {
667                                ToolUseStatus::Pending => "Pending",
668                                ToolUseStatus::Running => "Running",
669                                ToolUseStatus::Finished(_) => "Finished",
670                                ToolUseStatus::Error(_) => "Error",
671                            })
672                            .size(LabelSize::XSmall)
673                            .buffer_font(cx),
674                        ),
675                )
676                .map(|parent| {
677                    if !is_open {
678                        return parent;
679                    }
680
681                    parent.child(
682                        v_flex()
683                            .child(
684                                v_flex()
685                                    .gap_0p5()
686                                    .py_1()
687                                    .px_2p5()
688                                    .border_b_1()
689                                    .border_color(cx.theme().colors().border)
690                                    .child(Label::new("Input:"))
691                                    .child(Label::new(
692                                        serde_json::to_string_pretty(&tool_use.input)
693                                            .unwrap_or_default(),
694                                    )),
695                            )
696                            .map(|parent| match tool_use.status {
697                                ToolUseStatus::Finished(output) => parent.child(
698                                    v_flex()
699                                        .gap_0p5()
700                                        .py_1()
701                                        .px_2p5()
702                                        .child(Label::new("Result:"))
703                                        .child(Label::new(output)),
704                                ),
705                                ToolUseStatus::Error(err) => parent.child(
706                                    v_flex()
707                                        .gap_0p5()
708                                        .py_1()
709                                        .px_2p5()
710                                        .child(Label::new("Error:"))
711                                        .child(Label::new(err)),
712                                ),
713                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
714                            }),
715                    )
716                }),
717        )
718    }
719}
720
721impl Render for ActiveThread {
722    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
723        v_flex()
724            .size_full()
725            .child(list(self.list_state.clone()).flex_grow())
726    }
727}