// active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure};
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::tool_use::{ToolUse, ToolUseStatus};
 21use crate::ui::ContextPill;
 22
 23pub struct ActiveThread {
 24    workspace: WeakEntity<Workspace>,
 25    language_registry: Arc<LanguageRegistry>,
 26    tools: Arc<ToolWorkingSet>,
 27    thread_store: Entity<ThreadStore>,
 28    thread: Entity<Thread>,
 29    messages: Vec<MessageId>,
 30    list_state: ListState,
 31    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
 32    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
 33    last_error: Option<ThreadError>,
 34    _subscriptions: Vec<Subscription>,
 35}
 36
 37impl ActiveThread {
 38    pub fn new(
 39        thread: Entity<Thread>,
 40        thread_store: Entity<ThreadStore>,
 41        workspace: WeakEntity<Workspace>,
 42        language_registry: Arc<LanguageRegistry>,
 43        tools: Arc<ToolWorkingSet>,
 44        window: &mut Window,
 45        cx: &mut Context<Self>,
 46    ) -> Self {
 47        let subscriptions = vec![
 48            cx.observe(&thread, |_, _, cx| cx.notify()),
 49            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 50        ];
 51
 52        let mut this = Self {
 53            workspace,
 54            language_registry,
 55            tools,
 56            thread_store,
 57            thread: thread.clone(),
 58            messages: Vec::new(),
 59            rendered_messages_by_id: HashMap::default(),
 60            expanded_tool_uses: HashMap::default(),
 61            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 62                let this = cx.entity().downgrade();
 63                move |ix, _: &mut Window, cx: &mut App| {
 64                    this.update(cx, |this, cx| this.render_message(ix, cx))
 65                        .unwrap()
 66                }
 67            }),
 68            last_error: None,
 69            _subscriptions: subscriptions,
 70        };
 71
 72        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 73            this.push_message(&message.id, message.text.clone(), window, cx);
 74        }
 75
 76        this
 77    }
 78
 79    pub fn thread(&self) -> &Entity<Thread> {
 80        &self.thread
 81    }
 82
 83    pub fn is_empty(&self) -> bool {
 84        self.messages.is_empty()
 85    }
 86
 87    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 88        self.thread.read(cx).summary()
 89    }
 90
 91    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 92        self.thread.read(cx).summary_or_default()
 93    }
 94
 95    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 96        self.last_error.take();
 97        self.thread
 98            .update(cx, |thread, _cx| thread.cancel_last_completion())
 99    }
100
101    pub fn last_error(&self) -> Option<ThreadError> {
102        self.last_error.clone()
103    }
104
105    pub fn clear_last_error(&mut self) {
106        self.last_error.take();
107    }
108
109    fn push_message(
110        &mut self,
111        id: &MessageId,
112        text: String,
113        window: &mut Window,
114        cx: &mut Context<Self>,
115    ) {
116        let old_len = self.messages.len();
117        self.messages.push(*id);
118        self.list_state.splice(old_len..old_len, 1);
119
120        let theme_settings = ThemeSettings::get_global(cx);
121        let colors = cx.theme().colors();
122        let ui_font_size = TextSize::Default.rems(cx);
123        let buffer_font_size = TextSize::Small.rems(cx);
124        let mut text_style = window.text_style();
125
126        text_style.refine(&TextStyleRefinement {
127            font_family: Some(theme_settings.ui_font.family.clone()),
128            font_size: Some(ui_font_size.into()),
129            color: Some(cx.theme().colors().text),
130            ..Default::default()
131        });
132
133        let markdown_style = MarkdownStyle {
134            base_text_style: text_style,
135            syntax: cx.theme().syntax().clone(),
136            selection_background_color: cx.theme().players().local().selection,
137            code_block_overflow_x_scroll: true,
138            table_overflow_x_scroll: true,
139            code_block: StyleRefinement {
140                margin: EdgesRefinement {
141                    top: Some(Length::Definite(rems(0.).into())),
142                    left: Some(Length::Definite(rems(0.).into())),
143                    right: Some(Length::Definite(rems(0.).into())),
144                    bottom: Some(Length::Definite(rems(0.5).into())),
145                },
146                padding: EdgesRefinement {
147                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
149                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
150                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
151                },
152                background: Some(colors.editor_background.into()),
153                border_color: Some(colors.border_variant),
154                border_widths: EdgesRefinement {
155                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
156                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
157                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
158                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
159                },
160                text: Some(TextStyleRefinement {
161                    font_family: Some(theme_settings.buffer_font.family.clone()),
162                    font_size: Some(buffer_font_size.into()),
163                    ..Default::default()
164                }),
165                ..Default::default()
166            },
167            inline_code: TextStyleRefinement {
168                font_family: Some(theme_settings.buffer_font.family.clone()),
169                font_size: Some(buffer_font_size.into()),
170                background_color: Some(colors.editor_foreground.opacity(0.1)),
171                ..Default::default()
172            },
173            link: TextStyleRefinement {
174                background_color: Some(colors.editor_foreground.opacity(0.025)),
175                underline: Some(UnderlineStyle {
176                    color: Some(colors.text_accent.opacity(0.5)),
177                    thickness: px(1.),
178                    ..Default::default()
179                }),
180                ..Default::default()
181            },
182            ..Default::default()
183        };
184
185        let markdown = cx.new(|cx| {
186            Markdown::new(
187                text.into(),
188                markdown_style,
189                Some(self.language_registry.clone()),
190                None,
191                cx,
192            )
193        });
194        self.rendered_messages_by_id.insert(*id, markdown);
195        self.list_state.scroll_to(ListOffset {
196            item_ix: old_len,
197            offset_in_item: Pixels(0.0),
198        });
199    }
200
201    fn handle_thread_event(
202        &mut self,
203        _: &Entity<Thread>,
204        event: &ThreadEvent,
205        window: &mut Window,
206        cx: &mut Context<Self>,
207    ) {
208        match event {
209            ThreadEvent::ShowError(error) => {
210                self.last_error = Some(error.clone());
211            }
212            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
213                self.thread_store
214                    .update(cx, |thread_store, cx| {
215                        thread_store.save_thread(&self.thread, cx)
216                    })
217                    .detach_and_log_err(cx);
218            }
219            ThreadEvent::StreamedAssistantText(message_id, text) => {
220                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
221                    markdown.update(cx, |markdown, cx| {
222                        markdown.append(text, cx);
223                    });
224                }
225            }
226            ThreadEvent::MessageAdded(message_id) => {
227                if let Some(message_text) = self
228                    .thread
229                    .read(cx)
230                    .message(*message_id)
231                    .map(|message| message.text.clone())
232                {
233                    self.push_message(message_id, message_text, window, cx);
234                }
235
236                self.thread_store
237                    .update(cx, |thread_store, cx| {
238                        thread_store.save_thread(&self.thread, cx)
239                    })
240                    .detach_and_log_err(cx);
241
242                cx.notify();
243            }
244            ThreadEvent::UsePendingTools => {
245                let pending_tool_uses = self
246                    .thread
247                    .read(cx)
248                    .pending_tool_uses()
249                    .into_iter()
250                    .filter(|tool_use| tool_use.status.is_idle())
251                    .cloned()
252                    .collect::<Vec<_>>();
253
254                for tool_use in pending_tool_uses {
255                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
256                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
257
258                        self.thread.update(cx, |thread, cx| {
259                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
260                        });
261                    }
262                }
263            }
264            ThreadEvent::ToolFinished { .. } => {
265                let all_tools_finished = self
266                    .thread
267                    .read(cx)
268                    .pending_tool_uses()
269                    .into_iter()
270                    .all(|tool_use| tool_use.status.is_error());
271                if all_tools_finished {
272                    let model_registry = LanguageModelRegistry::read_global(cx);
273                    if let Some(model) = model_registry.active_model() {
274                        self.thread.update(cx, |thread, cx| {
275                            // Insert a user message to contain the tool results.
276                            thread.insert_user_message(
277                                // TODO: Sending up a user message without any content results in the model sending back
278                                // responses that also don't have any content. We currently don't handle this case well,
279                                // so for now we provide some text to keep the model on track.
280                                "Here are the tool results.",
281                                Vec::new(),
282                                cx,
283                            );
284                            thread.send_to_model(model, RequestKind::Chat, true, cx);
285                        });
286                    }
287                }
288            }
289        }
290    }
291
292    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
293        let message_id = self.messages[ix];
294        let Some(message) = self.thread.read(cx).message(message_id) else {
295            return Empty.into_any();
296        };
297
298        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
299            return Empty.into_any();
300        };
301
302        let context = self.thread.read(cx).context_for_message(message_id);
303        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
304        let colors = cx.theme().colors();
305
306        // Don't render user messages that are just there for returning tool results.
307        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
308            return Empty.into_any();
309        }
310
311        let message_content = v_flex()
312            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
313            .when_some(context, |parent, context| {
314                if !context.is_empty() {
315                    parent.child(
316                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
317                            context
318                                .into_iter()
319                                .map(|context| ContextPill::added(context, false, false, None)),
320                        ),
321                    )
322                } else {
323                    parent
324                }
325            });
326
327        let styled_message = match message.role {
328            Role::User => v_flex()
329                .id(("message-container", ix))
330                .pt_2p5()
331                .px_2p5()
332                .child(
333                    v_flex()
334                        .bg(colors.editor_background)
335                        .rounded_lg()
336                        .border_1()
337                        .border_color(colors.border)
338                        .shadow_sm()
339                        .child(
340                            h_flex()
341                                .py_1()
342                                .px_2()
343                                .bg(colors.editor_foreground.opacity(0.05))
344                                .border_b_1()
345                                .border_color(colors.border)
346                                .justify_between()
347                                .rounded_t(px(6.))
348                                .child(
349                                    h_flex()
350                                        .gap_1p5()
351                                        .child(
352                                            Icon::new(IconName::PersonCircle)
353                                                .size(IconSize::XSmall)
354                                                .color(Color::Muted),
355                                        )
356                                        .child(
357                                            Label::new("You")
358                                                .size(LabelSize::Small)
359                                                .color(Color::Muted),
360                                        ),
361                                ),
362                        )
363                        .child(message_content),
364                ),
365            Role::Assistant => div()
366                .id(("message-container", ix))
367                .child(message_content)
368                .map(|parent| {
369                    if tool_uses.is_empty() {
370                        return parent;
371                    }
372
373                    parent.child(
374                        v_flex().children(
375                            tool_uses
376                                .into_iter()
377                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
378                        ),
379                    )
380                }),
381            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
382                v_flex()
383                    .bg(colors.editor_background)
384                    .rounded_md()
385                    .child(message_content),
386            ),
387        };
388
389        styled_message.into_any()
390    }
391
392    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
393        let is_open = self
394            .expanded_tool_uses
395            .get(&tool_use.id)
396            .copied()
397            .unwrap_or_default();
398
399        div().px_2p5().child(
400            v_flex()
401                .gap_1()
402                .rounded_lg()
403                .border_1()
404                .border_color(cx.theme().colors().border)
405                .child(
406                    h_flex()
407                        .justify_between()
408                        .py_0p5()
409                        .pl_1()
410                        .pr_2()
411                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
412                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
413                        .when(!is_open, |element| element.rounded(px(6.)))
414                        .border_color(cx.theme().colors().border)
415                        .child(
416                            h_flex()
417                                .gap_1()
418                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
419                                    cx.listener({
420                                        let tool_use_id = tool_use.id.clone();
421                                        move |this, _event, _window, _cx| {
422                                            let is_open = this
423                                                .expanded_tool_uses
424                                                .entry(tool_use_id.clone())
425                                                .or_insert(false);
426
427                                            *is_open = !*is_open;
428                                        }
429                                    }),
430                                ))
431                                .child(Label::new(tool_use.name)),
432                        )
433                        .child(
434                            Label::new(match tool_use.status {
435                                ToolUseStatus::Pending => "Pending",
436                                ToolUseStatus::Running => "Running",
437                                ToolUseStatus::Finished(_) => "Finished",
438                                ToolUseStatus::Error(_) => "Error",
439                            })
440                            .size(LabelSize::XSmall)
441                            .buffer_font(cx),
442                        ),
443                )
444                .map(|parent| {
445                    if !is_open {
446                        return parent;
447                    }
448
449                    parent.child(
450                        v_flex()
451                            .child(
452                                v_flex()
453                                    .gap_0p5()
454                                    .py_1()
455                                    .px_2p5()
456                                    .border_b_1()
457                                    .border_color(cx.theme().colors().border)
458                                    .child(Label::new("Input:"))
459                                    .child(Label::new(
460                                        serde_json::to_string_pretty(&tool_use.input)
461                                            .unwrap_or_default(),
462                                    )),
463                            )
464                            .map(|parent| match tool_use.status {
465                                ToolUseStatus::Finished(output) => parent.child(
466                                    v_flex()
467                                        .gap_0p5()
468                                        .py_1()
469                                        .px_2p5()
470                                        .child(Label::new("Result:"))
471                                        .child(Label::new(output)),
472                                ),
473                                ToolUseStatus::Error(err) => parent.child(
474                                    v_flex()
475                                        .gap_0p5()
476                                        .py_1()
477                                        .px_2p5()
478                                        .child(Label::new("Error:"))
479                                        .child(Label::new(err)),
480                                ),
481                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
482                            }),
483                    )
484                }),
485        )
486    }
487}
488
489impl Render for ActiveThread {
490    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
491        v_flex()
492            .size_full()
493            .child(list(self.list_state.clone()).flex_grow())
494    }
495}