active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure};
 16use workspace::Workspace;
 17
 18use crate::thread::{
 19    MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus,
 20};
 21use crate::thread_store::ThreadStore;
 22use crate::ui::ContextPill;
 23
/// A view that renders the messages of a [`Thread`] as a scrolling list and reacts to the
/// thread's events (streamed assistant text, newly added messages, errors, and tool use).
pub struct ActiveThread {
    /// Workspace handed to tools when they are run.
    workspace: WeakEntity<Workspace>,
    /// Passed to every message's [`Markdown`] (presumably for code-block highlighting).
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks for its pending tool uses to run.
    tools: Arc<ToolWorkingSet>,
    /// Store the thread is saved to after streamed completions, summary changes, and new
    /// messages.
    thread: Entity<Thread>,
    /// Message ids in display order; index `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// Virtualized list state; one item is spliced in per entry pushed onto `messages`.
    list_state: ListState,
    /// Rendered markdown for each message; assistant text is appended as it streams in.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// Whether each tool use's disclosure is expanded; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported by the thread, kept until cleared or a completion is
    /// cancelled.
    last_error: Option<ThreadError>,
    /// Keeps the observation of, and event subscription to, `thread` alive.
    _subscriptions: Vec<Subscription>,
}
 37
 38impl ActiveThread {
 39    pub fn new(
 40        thread: Entity<Thread>,
 41        thread_store: Entity<ThreadStore>,
 42        workspace: WeakEntity<Workspace>,
 43        language_registry: Arc<LanguageRegistry>,
 44        tools: Arc<ToolWorkingSet>,
 45        window: &mut Window,
 46        cx: &mut Context<Self>,
 47    ) -> Self {
 48        let subscriptions = vec![
 49            cx.observe(&thread, |_, _, cx| cx.notify()),
 50            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 51        ];
 52
 53        let mut this = Self {
 54            workspace,
 55            language_registry,
 56            tools,
 57            thread_store,
 58            thread: thread.clone(),
 59            messages: Vec::new(),
 60            rendered_messages_by_id: HashMap::default(),
 61            expanded_tool_uses: HashMap::default(),
 62            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 63                let this = cx.entity().downgrade();
 64                move |ix, _: &mut Window, cx: &mut App| {
 65                    this.update(cx, |this, cx| this.render_message(ix, cx))
 66                        .unwrap()
 67                }
 68            }),
 69            last_error: None,
 70            _subscriptions: subscriptions,
 71        };
 72
 73        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 74            this.push_message(&message.id, message.text.clone(), window, cx);
 75        }
 76
 77        this
 78    }
 79
 80    pub fn thread(&self) -> &Entity<Thread> {
 81        &self.thread
 82    }
 83
 84    pub fn is_empty(&self) -> bool {
 85        self.messages.is_empty()
 86    }
 87
 88    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 89        self.thread.read(cx).summary()
 90    }
 91
 92    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 93        self.thread.read(cx).summary_or_default()
 94    }
 95
 96    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 97        self.last_error.take();
 98        self.thread
 99            .update(cx, |thread, _cx| thread.cancel_last_completion())
100    }
101
102    pub fn last_error(&self) -> Option<ThreadError> {
103        self.last_error.clone()
104    }
105
106    pub fn clear_last_error(&mut self) {
107        self.last_error.take();
108    }
109
110    fn push_message(
111        &mut self,
112        id: &MessageId,
113        text: String,
114        window: &mut Window,
115        cx: &mut Context<Self>,
116    ) {
117        let old_len = self.messages.len();
118        self.messages.push(*id);
119        self.list_state.splice(old_len..old_len, 1);
120
121        let theme_settings = ThemeSettings::get_global(cx);
122        let colors = cx.theme().colors();
123        let ui_font_size = TextSize::Default.rems(cx);
124        let buffer_font_size = TextSize::Small.rems(cx);
125        let mut text_style = window.text_style();
126
127        text_style.refine(&TextStyleRefinement {
128            font_family: Some(theme_settings.ui_font.family.clone()),
129            font_size: Some(ui_font_size.into()),
130            color: Some(cx.theme().colors().text),
131            ..Default::default()
132        });
133
134        let markdown_style = MarkdownStyle {
135            base_text_style: text_style,
136            syntax: cx.theme().syntax().clone(),
137            selection_background_color: cx.theme().players().local().selection,
138            code_block: StyleRefinement {
139                margin: EdgesRefinement {
140                    top: Some(Length::Definite(rems(0.).into())),
141                    left: Some(Length::Definite(rems(0.).into())),
142                    right: Some(Length::Definite(rems(0.).into())),
143                    bottom: Some(Length::Definite(rems(0.5).into())),
144                },
145                padding: EdgesRefinement {
146                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
147                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
149                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
150                },
151                background: Some(colors.editor_background.into()),
152                border_color: Some(colors.border_variant),
153                border_widths: EdgesRefinement {
154                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
155                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
156                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
157                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
158                },
159                text: Some(TextStyleRefinement {
160                    font_family: Some(theme_settings.buffer_font.family.clone()),
161                    font_size: Some(buffer_font_size.into()),
162                    ..Default::default()
163                }),
164                ..Default::default()
165            },
166            inline_code: TextStyleRefinement {
167                font_family: Some(theme_settings.buffer_font.family.clone()),
168                font_size: Some(buffer_font_size.into()),
169                background_color: Some(colors.editor_foreground.opacity(0.1)),
170                ..Default::default()
171            },
172            link: TextStyleRefinement {
173                background_color: Some(colors.editor_foreground.opacity(0.025)),
174                underline: Some(UnderlineStyle {
175                    color: Some(colors.text_accent.opacity(0.5)),
176                    thickness: px(1.),
177                    ..Default::default()
178                }),
179                ..Default::default()
180            },
181            ..Default::default()
182        };
183
184        let markdown = cx.new(|cx| {
185            Markdown::new(
186                text.into(),
187                markdown_style,
188                Some(self.language_registry.clone()),
189                None,
190                cx,
191            )
192        });
193        self.rendered_messages_by_id.insert(*id, markdown);
194        self.list_state.scroll_to(ListOffset {
195            item_ix: old_len,
196            offset_in_item: Pixels(0.0),
197        });
198    }
199
200    fn handle_thread_event(
201        &mut self,
202        _: &Entity<Thread>,
203        event: &ThreadEvent,
204        window: &mut Window,
205        cx: &mut Context<Self>,
206    ) {
207        match event {
208            ThreadEvent::ShowError(error) => {
209                self.last_error = Some(error.clone());
210            }
211            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
212                self.thread_store
213                    .update(cx, |thread_store, cx| {
214                        thread_store.save_thread(&self.thread, cx)
215                    })
216                    .detach_and_log_err(cx);
217            }
218            ThreadEvent::StreamedAssistantText(message_id, text) => {
219                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
220                    markdown.update(cx, |markdown, cx| {
221                        markdown.append(text, cx);
222                    });
223                }
224            }
225            ThreadEvent::MessageAdded(message_id) => {
226                if let Some(message_text) = self
227                    .thread
228                    .read(cx)
229                    .message(*message_id)
230                    .map(|message| message.text.clone())
231                {
232                    self.push_message(message_id, message_text, window, cx);
233                }
234
235                self.thread_store
236                    .update(cx, |thread_store, cx| {
237                        thread_store.save_thread(&self.thread, cx)
238                    })
239                    .detach_and_log_err(cx);
240
241                cx.notify();
242            }
243            ThreadEvent::UsePendingTools => {
244                let pending_tool_uses = self
245                    .thread
246                    .read(cx)
247                    .pending_tool_uses()
248                    .into_iter()
249                    .filter(|tool_use| tool_use.status.is_idle())
250                    .cloned()
251                    .collect::<Vec<_>>();
252
253                for tool_use in pending_tool_uses {
254                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
255                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
256
257                        self.thread.update(cx, |thread, cx| {
258                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
259                        });
260                    }
261                }
262            }
263            ThreadEvent::ToolFinished { .. } => {
264                let all_tools_finished = self
265                    .thread
266                    .read(cx)
267                    .pending_tool_uses()
268                    .into_iter()
269                    .all(|tool_use| tool_use.status.is_error());
270                if all_tools_finished {
271                    let model_registry = LanguageModelRegistry::read_global(cx);
272                    if let Some(model) = model_registry.active_model() {
273                        self.thread.update(cx, |thread, cx| {
274                            // Insert an empty user message to contain the tool results.
275                            thread.insert_user_message("", Vec::new(), cx);
276                            thread.send_to_model(model, RequestKind::Chat, true, cx);
277                        });
278                    }
279                }
280            }
281        }
282    }
283
284    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
285        let message_id = self.messages[ix];
286        let Some(message) = self.thread.read(cx).message(message_id) else {
287            return Empty.into_any();
288        };
289
290        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
291            return Empty.into_any();
292        };
293
294        let context = self.thread.read(cx).context_for_message(message_id);
295        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
296        let colors = cx.theme().colors();
297
298        // Don't render user messages that are just there for returning tool results.
299        if message.role == Role::User
300            && message.text.is_empty()
301            && self.thread.read(cx).message_has_tool_results(message_id)
302        {
303            return Empty.into_any();
304        }
305
306        let message_content = v_flex()
307            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
308            .when_some(context, |parent, context| {
309                if !context.is_empty() {
310                    parent.child(
311                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
312                            context
313                                .into_iter()
314                                .map(|context| ContextPill::added(context, false, false, None)),
315                        ),
316                    )
317                } else {
318                    parent
319                }
320            });
321
322        let styled_message = match message.role {
323            Role::User => v_flex()
324                .id(("message-container", ix))
325                .pt_2p5()
326                .px_2p5()
327                .child(
328                    v_flex()
329                        .bg(colors.editor_background)
330                        .rounded_lg()
331                        .border_1()
332                        .border_color(colors.border)
333                        .shadow_sm()
334                        .child(
335                            h_flex()
336                                .py_1()
337                                .px_2()
338                                .bg(colors.editor_foreground.opacity(0.05))
339                                .border_b_1()
340                                .border_color(colors.border)
341                                .justify_between()
342                                .rounded_t(px(6.))
343                                .child(
344                                    h_flex()
345                                        .gap_1p5()
346                                        .child(
347                                            Icon::new(IconName::PersonCircle)
348                                                .size(IconSize::XSmall)
349                                                .color(Color::Muted),
350                                        )
351                                        .child(
352                                            Label::new("You")
353                                                .size(LabelSize::Small)
354                                                .color(Color::Muted),
355                                        ),
356                                ),
357                        )
358                        .child(message_content),
359                ),
360            Role::Assistant => div()
361                .id(("message-container", ix))
362                .child(message_content)
363                .map(|parent| {
364                    if tool_uses.is_empty() {
365                        return parent;
366                    }
367
368                    parent.child(
369                        v_flex().children(
370                            tool_uses
371                                .into_iter()
372                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
373                        ),
374                    )
375                }),
376            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
377                v_flex()
378                    .bg(colors.editor_background)
379                    .rounded_md()
380                    .child(message_content),
381            ),
382        };
383
384        styled_message.into_any()
385    }
386
387    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
388        let is_open = self
389            .expanded_tool_uses
390            .get(&tool_use.id)
391            .copied()
392            .unwrap_or_default();
393
394        div().px_2p5().child(
395            v_flex()
396                .gap_1()
397                .rounded_lg()
398                .border_1()
399                .border_color(cx.theme().colors().border)
400                .child(
401                    h_flex()
402                        .justify_between()
403                        .py_0p5()
404                        .pl_1()
405                        .pr_2()
406                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
407                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
408                        .when(!is_open, |element| element.rounded(px(6.)))
409                        .border_color(cx.theme().colors().border)
410                        .child(
411                            h_flex()
412                                .gap_1()
413                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
414                                    cx.listener({
415                                        let tool_use_id = tool_use.id.clone();
416                                        move |this, _event, _window, _cx| {
417                                            let is_open = this
418                                                .expanded_tool_uses
419                                                .entry(tool_use_id.clone())
420                                                .or_insert(false);
421
422                                            *is_open = !*is_open;
423                                        }
424                                    }),
425                                ))
426                                .child(Label::new(tool_use.name)),
427                        )
428                        .child(
429                            Label::new(match tool_use.status {
430                                ToolUseStatus::Pending => "Pending",
431                                ToolUseStatus::Running => "Running",
432                                ToolUseStatus::Finished(_) => "Finished",
433                                ToolUseStatus::Error(_) => "Error",
434                            })
435                            .size(LabelSize::XSmall)
436                            .buffer_font(cx),
437                        ),
438                )
439                .map(|parent| {
440                    if !is_open {
441                        return parent;
442                    }
443
444                    parent.child(
445                        v_flex()
446                            .child(
447                                v_flex()
448                                    .gap_0p5()
449                                    .py_1()
450                                    .px_2p5()
451                                    .border_b_1()
452                                    .border_color(cx.theme().colors().border)
453                                    .child(Label::new("Input:"))
454                                    .child(Label::new(
455                                        serde_json::to_string_pretty(&tool_use.input)
456                                            .unwrap_or_default(),
457                                    )),
458                            )
459                            .map(|parent| match tool_use.status {
460                                ToolUseStatus::Finished(output) => parent.child(
461                                    v_flex()
462                                        .gap_0p5()
463                                        .py_1()
464                                        .px_2p5()
465                                        .child(Label::new("Result:"))
466                                        .child(Label::new(output)),
467                                ),
468                                ToolUseStatus::Error(err) => parent.child(
469                                    v_flex()
470                                        .gap_0p5()
471                                        .py_1()
472                                        .px_2p5()
473                                        .child(Label::new("Error:"))
474                                        .child(Label::new(err)),
475                                ),
476                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
477                            }),
478                    )
479                }),
480        )
481    }
482}
483
484impl Render for ActiveThread {
485    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
486        v_flex()
487            .size_full()
488            .child(list(self.list_state.clone()).flex_grow())
489    }
490}