active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure};
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::tool_use::{ToolUse, ToolUseStatus};
 21use crate::ui::ContextPill;
 22
 23pub struct ActiveThread {
 24    workspace: WeakEntity<Workspace>,
 25    language_registry: Arc<LanguageRegistry>,
 26    tools: Arc<ToolWorkingSet>,
 27    thread_store: Entity<ThreadStore>,
 28    thread: Entity<Thread>,
 29    messages: Vec<MessageId>,
 30    list_state: ListState,
 31    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
 32    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
 33    last_error: Option<ThreadError>,
 34    _subscriptions: Vec<Subscription>,
 35}
 36
 37impl ActiveThread {
 38    pub fn new(
 39        thread: Entity<Thread>,
 40        thread_store: Entity<ThreadStore>,
 41        workspace: WeakEntity<Workspace>,
 42        language_registry: Arc<LanguageRegistry>,
 43        tools: Arc<ToolWorkingSet>,
 44        window: &mut Window,
 45        cx: &mut Context<Self>,
 46    ) -> Self {
 47        let subscriptions = vec![
 48            cx.observe(&thread, |_, _, cx| cx.notify()),
 49            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 50        ];
 51
 52        let mut this = Self {
 53            workspace,
 54            language_registry,
 55            tools,
 56            thread_store,
 57            thread: thread.clone(),
 58            messages: Vec::new(),
 59            rendered_messages_by_id: HashMap::default(),
 60            expanded_tool_uses: HashMap::default(),
 61            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 62                let this = cx.entity().downgrade();
 63                move |ix, _: &mut Window, cx: &mut App| {
 64                    this.update(cx, |this, cx| this.render_message(ix, cx))
 65                        .unwrap()
 66                }
 67            }),
 68            last_error: None,
 69            _subscriptions: subscriptions,
 70        };
 71
 72        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 73            this.push_message(&message.id, message.text.clone(), window, cx);
 74        }
 75
 76        this
 77    }
 78
 79    pub fn thread(&self) -> &Entity<Thread> {
 80        &self.thread
 81    }
 82
 83    pub fn is_empty(&self) -> bool {
 84        self.messages.is_empty()
 85    }
 86
 87    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 88        self.thread.read(cx).summary()
 89    }
 90
 91    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 92        self.thread.read(cx).summary_or_default()
 93    }
 94
 95    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 96        self.last_error.take();
 97        self.thread
 98            .update(cx, |thread, _cx| thread.cancel_last_completion())
 99    }
100
101    pub fn last_error(&self) -> Option<ThreadError> {
102        self.last_error.clone()
103    }
104
105    pub fn clear_last_error(&mut self) {
106        self.last_error.take();
107    }
108
109    fn push_message(
110        &mut self,
111        id: &MessageId,
112        text: String,
113        window: &mut Window,
114        cx: &mut Context<Self>,
115    ) {
116        let old_len = self.messages.len();
117        self.messages.push(*id);
118        self.list_state.splice(old_len..old_len, 1);
119
120        let theme_settings = ThemeSettings::get_global(cx);
121        let colors = cx.theme().colors();
122        let ui_font_size = TextSize::Default.rems(cx);
123        let buffer_font_size = TextSize::Small.rems(cx);
124        let mut text_style = window.text_style();
125
126        text_style.refine(&TextStyleRefinement {
127            font_family: Some(theme_settings.ui_font.family.clone()),
128            font_size: Some(ui_font_size.into()),
129            color: Some(cx.theme().colors().text),
130            ..Default::default()
131        });
132
133        let markdown_style = MarkdownStyle {
134            base_text_style: text_style,
135            syntax: cx.theme().syntax().clone(),
136            selection_background_color: cx.theme().players().local().selection,
137            code_block: StyleRefinement {
138                margin: EdgesRefinement {
139                    top: Some(Length::Definite(rems(0.).into())),
140                    left: Some(Length::Definite(rems(0.).into())),
141                    right: Some(Length::Definite(rems(0.).into())),
142                    bottom: Some(Length::Definite(rems(0.5).into())),
143                },
144                padding: EdgesRefinement {
145                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
147                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
149                },
150                background: Some(colors.editor_background.into()),
151                border_color: Some(colors.border_variant),
152                border_widths: EdgesRefinement {
153                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
154                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
155                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
156                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
157                },
158                text: Some(TextStyleRefinement {
159                    font_family: Some(theme_settings.buffer_font.family.clone()),
160                    font_size: Some(buffer_font_size.into()),
161                    ..Default::default()
162                }),
163                ..Default::default()
164            },
165            inline_code: TextStyleRefinement {
166                font_family: Some(theme_settings.buffer_font.family.clone()),
167                font_size: Some(buffer_font_size.into()),
168                background_color: Some(colors.editor_foreground.opacity(0.1)),
169                ..Default::default()
170            },
171            link: TextStyleRefinement {
172                background_color: Some(colors.editor_foreground.opacity(0.025)),
173                underline: Some(UnderlineStyle {
174                    color: Some(colors.text_accent.opacity(0.5)),
175                    thickness: px(1.),
176                    ..Default::default()
177                }),
178                ..Default::default()
179            },
180            ..Default::default()
181        };
182
183        let markdown = cx.new(|cx| {
184            Markdown::new(
185                text.into(),
186                markdown_style,
187                Some(self.language_registry.clone()),
188                None,
189                cx,
190            )
191        });
192        self.rendered_messages_by_id.insert(*id, markdown);
193        self.list_state.scroll_to(ListOffset {
194            item_ix: old_len,
195            offset_in_item: Pixels(0.0),
196        });
197    }
198
199    fn handle_thread_event(
200        &mut self,
201        _: &Entity<Thread>,
202        event: &ThreadEvent,
203        window: &mut Window,
204        cx: &mut Context<Self>,
205    ) {
206        match event {
207            ThreadEvent::ShowError(error) => {
208                self.last_error = Some(error.clone());
209            }
210            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
211                self.thread_store
212                    .update(cx, |thread_store, cx| {
213                        thread_store.save_thread(&self.thread, cx)
214                    })
215                    .detach_and_log_err(cx);
216            }
217            ThreadEvent::StreamedAssistantText(message_id, text) => {
218                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
219                    markdown.update(cx, |markdown, cx| {
220                        markdown.append(text, cx);
221                    });
222                }
223            }
224            ThreadEvent::MessageAdded(message_id) => {
225                if let Some(message_text) = self
226                    .thread
227                    .read(cx)
228                    .message(*message_id)
229                    .map(|message| message.text.clone())
230                {
231                    self.push_message(message_id, message_text, window, cx);
232                }
233
234                self.thread_store
235                    .update(cx, |thread_store, cx| {
236                        thread_store.save_thread(&self.thread, cx)
237                    })
238                    .detach_and_log_err(cx);
239
240                cx.notify();
241            }
242            ThreadEvent::UsePendingTools => {
243                let pending_tool_uses = self
244                    .thread
245                    .read(cx)
246                    .pending_tool_uses()
247                    .into_iter()
248                    .filter(|tool_use| tool_use.status.is_idle())
249                    .cloned()
250                    .collect::<Vec<_>>();
251
252                for tool_use in pending_tool_uses {
253                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
254                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
255
256                        self.thread.update(cx, |thread, cx| {
257                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
258                        });
259                    }
260                }
261            }
262            ThreadEvent::ToolFinished { .. } => {
263                let all_tools_finished = self
264                    .thread
265                    .read(cx)
266                    .pending_tool_uses()
267                    .into_iter()
268                    .all(|tool_use| tool_use.status.is_error());
269                if all_tools_finished {
270                    let model_registry = LanguageModelRegistry::read_global(cx);
271                    if let Some(model) = model_registry.active_model() {
272                        self.thread.update(cx, |thread, cx| {
273                            // Insert a user message to contain the tool results.
274                            thread.insert_user_message(
275                                // TODO: Sending up a user message without any content results in the model sending back
276                                // responses that also don't have any content. We currently don't handle this case well,
277                                // so for now we provide some text to keep the model on track.
278                                "Here are the tool results.",
279                                Vec::new(),
280                                cx,
281                            );
282                            thread.send_to_model(model, RequestKind::Chat, true, cx);
283                        });
284                    }
285                }
286            }
287        }
288    }
289
290    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
291        let message_id = self.messages[ix];
292        let Some(message) = self.thread.read(cx).message(message_id) else {
293            return Empty.into_any();
294        };
295
296        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
297            return Empty.into_any();
298        };
299
300        let context = self.thread.read(cx).context_for_message(message_id);
301        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
302        let colors = cx.theme().colors();
303
304        // Don't render user messages that are just there for returning tool results.
305        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
306            return Empty.into_any();
307        }
308
309        let message_content = v_flex()
310            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
311            .when_some(context, |parent, context| {
312                if !context.is_empty() {
313                    parent.child(
314                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
315                            context
316                                .into_iter()
317                                .map(|context| ContextPill::added(context, false, false, None)),
318                        ),
319                    )
320                } else {
321                    parent
322                }
323            });
324
325        let styled_message = match message.role {
326            Role::User => v_flex()
327                .id(("message-container", ix))
328                .pt_2p5()
329                .px_2p5()
330                .child(
331                    v_flex()
332                        .bg(colors.editor_background)
333                        .rounded_lg()
334                        .border_1()
335                        .border_color(colors.border)
336                        .shadow_sm()
337                        .child(
338                            h_flex()
339                                .py_1()
340                                .px_2()
341                                .bg(colors.editor_foreground.opacity(0.05))
342                                .border_b_1()
343                                .border_color(colors.border)
344                                .justify_between()
345                                .rounded_t(px(6.))
346                                .child(
347                                    h_flex()
348                                        .gap_1p5()
349                                        .child(
350                                            Icon::new(IconName::PersonCircle)
351                                                .size(IconSize::XSmall)
352                                                .color(Color::Muted),
353                                        )
354                                        .child(
355                                            Label::new("You")
356                                                .size(LabelSize::Small)
357                                                .color(Color::Muted),
358                                        ),
359                                ),
360                        )
361                        .child(message_content),
362                ),
363            Role::Assistant => div()
364                .id(("message-container", ix))
365                .child(message_content)
366                .map(|parent| {
367                    if tool_uses.is_empty() {
368                        return parent;
369                    }
370
371                    parent.child(
372                        v_flex().children(
373                            tool_uses
374                                .into_iter()
375                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
376                        ),
377                    )
378                }),
379            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
380                v_flex()
381                    .bg(colors.editor_background)
382                    .rounded_md()
383                    .child(message_content),
384            ),
385        };
386
387        styled_message.into_any()
388    }
389
390    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
391        let is_open = self
392            .expanded_tool_uses
393            .get(&tool_use.id)
394            .copied()
395            .unwrap_or_default();
396
397        div().px_2p5().child(
398            v_flex()
399                .gap_1()
400                .rounded_lg()
401                .border_1()
402                .border_color(cx.theme().colors().border)
403                .child(
404                    h_flex()
405                        .justify_between()
406                        .py_0p5()
407                        .pl_1()
408                        .pr_2()
409                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
410                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
411                        .when(!is_open, |element| element.rounded(px(6.)))
412                        .border_color(cx.theme().colors().border)
413                        .child(
414                            h_flex()
415                                .gap_1()
416                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
417                                    cx.listener({
418                                        let tool_use_id = tool_use.id.clone();
419                                        move |this, _event, _window, _cx| {
420                                            let is_open = this
421                                                .expanded_tool_uses
422                                                .entry(tool_use_id.clone())
423                                                .or_insert(false);
424
425                                            *is_open = !*is_open;
426                                        }
427                                    }),
428                                ))
429                                .child(Label::new(tool_use.name)),
430                        )
431                        .child(
432                            Label::new(match tool_use.status {
433                                ToolUseStatus::Pending => "Pending",
434                                ToolUseStatus::Running => "Running",
435                                ToolUseStatus::Finished(_) => "Finished",
436                                ToolUseStatus::Error(_) => "Error",
437                            })
438                            .size(LabelSize::XSmall)
439                            .buffer_font(cx),
440                        ),
441                )
442                .map(|parent| {
443                    if !is_open {
444                        return parent;
445                    }
446
447                    parent.child(
448                        v_flex()
449                            .child(
450                                v_flex()
451                                    .gap_0p5()
452                                    .py_1()
453                                    .px_2p5()
454                                    .border_b_1()
455                                    .border_color(cx.theme().colors().border)
456                                    .child(Label::new("Input:"))
457                                    .child(Label::new(
458                                        serde_json::to_string_pretty(&tool_use.input)
459                                            .unwrap_or_default(),
460                                    )),
461                            )
462                            .map(|parent| match tool_use.status {
463                                ToolUseStatus::Finished(output) => parent.child(
464                                    v_flex()
465                                        .gap_0p5()
466                                        .py_1()
467                                        .px_2p5()
468                                        .child(Label::new("Result:"))
469                                        .child(Label::new(output)),
470                                ),
471                                ToolUseStatus::Error(err) => parent.child(
472                                    v_flex()
473                                        .gap_0p5()
474                                        .py_1()
475                                        .px_2p5()
476                                        .child(Label::new("Error:"))
477                                        .child(Label::new(err)),
478                                ),
479                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
480                            }),
481                    )
482                }),
483        )
484    }
485}
486
487impl Render for ActiveThread {
488    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
489        v_flex()
490            .size_full()
491            .child(list(self.list_state.clone()).flex_grow())
492    }
493}