active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::{LanguageModelToolUseId, Role};
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::{prelude::*, Disclosure};
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus};
 19use crate::thread_store::ThreadStore;
 20use crate::ui::ContextPill;
 21
 22pub struct ActiveThread {
 23    workspace: WeakEntity<Workspace>,
 24    language_registry: Arc<LanguageRegistry>,
 25    tools: Arc<ToolWorkingSet>,
 26    thread_store: Entity<ThreadStore>,
 27    thread: Entity<Thread>,
 28    messages: Vec<MessageId>,
 29    list_state: ListState,
 30    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
 31    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
 32    last_error: Option<ThreadError>,
 33    _subscriptions: Vec<Subscription>,
 34}
 35
 36impl ActiveThread {
 37    pub fn new(
 38        thread: Entity<Thread>,
 39        thread_store: Entity<ThreadStore>,
 40        workspace: WeakEntity<Workspace>,
 41        language_registry: Arc<LanguageRegistry>,
 42        tools: Arc<ToolWorkingSet>,
 43        window: &mut Window,
 44        cx: &mut Context<Self>,
 45    ) -> Self {
 46        let subscriptions = vec![
 47            cx.observe(&thread, |_, _, cx| cx.notify()),
 48            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 49        ];
 50
 51        let mut this = Self {
 52            workspace,
 53            language_registry,
 54            tools,
 55            thread_store,
 56            thread: thread.clone(),
 57            messages: Vec::new(),
 58            rendered_messages_by_id: HashMap::default(),
 59            expanded_tool_uses: HashMap::default(),
 60            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 61                let this = cx.entity().downgrade();
 62                move |ix, _: &mut Window, cx: &mut App| {
 63                    this.update(cx, |this, cx| this.render_message(ix, cx))
 64                        .unwrap()
 65                }
 66            }),
 67            last_error: None,
 68            _subscriptions: subscriptions,
 69        };
 70
 71        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 72            this.push_message(&message.id, message.text.clone(), window, cx);
 73        }
 74
 75        this
 76    }
 77
 78    pub fn thread(&self) -> &Entity<Thread> {
 79        &self.thread
 80    }
 81
 82    pub fn is_empty(&self) -> bool {
 83        self.messages.is_empty()
 84    }
 85
 86    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 87        self.thread.read(cx).summary()
 88    }
 89
 90    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 91        self.thread.read(cx).summary_or_default()
 92    }
 93
 94    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 95        self.last_error.take();
 96        self.thread
 97            .update(cx, |thread, _cx| thread.cancel_last_completion())
 98    }
 99
100    pub fn last_error(&self) -> Option<ThreadError> {
101        self.last_error.clone()
102    }
103
104    pub fn clear_last_error(&mut self) {
105        self.last_error.take();
106    }
107
108    fn push_message(
109        &mut self,
110        id: &MessageId,
111        text: String,
112        window: &mut Window,
113        cx: &mut Context<Self>,
114    ) {
115        let old_len = self.messages.len();
116        self.messages.push(*id);
117        self.list_state.splice(old_len..old_len, 1);
118
119        let theme_settings = ThemeSettings::get_global(cx);
120        let colors = cx.theme().colors();
121        let ui_font_size = TextSize::Default.rems(cx);
122        let buffer_font_size = TextSize::Small.rems(cx);
123        let mut text_style = window.text_style();
124
125        text_style.refine(&TextStyleRefinement {
126            font_family: Some(theme_settings.ui_font.family.clone()),
127            font_size: Some(ui_font_size.into()),
128            color: Some(cx.theme().colors().text),
129            ..Default::default()
130        });
131
132        let markdown_style = MarkdownStyle {
133            base_text_style: text_style,
134            syntax: cx.theme().syntax().clone(),
135            selection_background_color: cx.theme().players().local().selection,
136            code_block: StyleRefinement {
137                margin: EdgesRefinement {
138                    top: Some(Length::Definite(rems(0.).into())),
139                    left: Some(Length::Definite(rems(0.).into())),
140                    right: Some(Length::Definite(rems(0.).into())),
141                    bottom: Some(Length::Definite(rems(0.5).into())),
142                },
143                padding: EdgesRefinement {
144                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
145                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
147                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148                },
149                background: Some(colors.editor_background.into()),
150                border_color: Some(colors.border_variant),
151                border_widths: EdgesRefinement {
152                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
153                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
154                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
155                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
156                },
157                text: Some(TextStyleRefinement {
158                    font_family: Some(theme_settings.buffer_font.family.clone()),
159                    font_size: Some(buffer_font_size.into()),
160                    ..Default::default()
161                }),
162                ..Default::default()
163            },
164            inline_code: TextStyleRefinement {
165                font_family: Some(theme_settings.buffer_font.family.clone()),
166                font_size: Some(buffer_font_size.into()),
167                background_color: Some(colors.editor_foreground.opacity(0.1)),
168                ..Default::default()
169            },
170            link: TextStyleRefinement {
171                background_color: Some(colors.editor_foreground.opacity(0.025)),
172                underline: Some(UnderlineStyle {
173                    color: Some(colors.text_accent.opacity(0.5)),
174                    thickness: px(1.),
175                    ..Default::default()
176                }),
177                ..Default::default()
178            },
179            ..Default::default()
180        };
181
182        let markdown = cx.new(|cx| {
183            Markdown::new(
184                text.into(),
185                markdown_style,
186                Some(self.language_registry.clone()),
187                None,
188                cx,
189            )
190        });
191        self.rendered_messages_by_id.insert(*id, markdown);
192        self.list_state.scroll_to(ListOffset {
193            item_ix: old_len,
194            offset_in_item: Pixels(0.0),
195        });
196    }
197
198    fn handle_thread_event(
199        &mut self,
200        _: &Entity<Thread>,
201        event: &ThreadEvent,
202        window: &mut Window,
203        cx: &mut Context<Self>,
204    ) {
205        match event {
206            ThreadEvent::ShowError(error) => {
207                self.last_error = Some(error.clone());
208            }
209            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
210                self.thread_store
211                    .update(cx, |thread_store, cx| {
212                        thread_store.save_thread(&self.thread, cx)
213                    })
214                    .detach_and_log_err(cx);
215            }
216            ThreadEvent::StreamedAssistantText(message_id, text) => {
217                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
218                    markdown.update(cx, |markdown, cx| {
219                        markdown.append(text, cx);
220                    });
221                }
222            }
223            ThreadEvent::MessageAdded(message_id) => {
224                if let Some(message_text) = self
225                    .thread
226                    .read(cx)
227                    .message(*message_id)
228                    .map(|message| message.text.clone())
229                {
230                    self.push_message(message_id, message_text, window, cx);
231                }
232
233                self.thread_store
234                    .update(cx, |thread_store, cx| {
235                        thread_store.save_thread(&self.thread, cx)
236                    })
237                    .detach_and_log_err(cx);
238
239                cx.notify();
240            }
241            ThreadEvent::UsePendingTools => {
242                let pending_tool_uses = self
243                    .thread
244                    .read(cx)
245                    .pending_tool_uses()
246                    .into_iter()
247                    .filter(|tool_use| tool_use.status.is_idle())
248                    .cloned()
249                    .collect::<Vec<_>>();
250
251                for tool_use in pending_tool_uses {
252                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
253                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
254
255                        self.thread.update(cx, |thread, cx| {
256                            thread.insert_tool_output(
257                                tool_use.assistant_message_id,
258                                tool_use.id.clone(),
259                                task,
260                                cx,
261                            );
262                        });
263                    }
264                }
265            }
266            ThreadEvent::ToolFinished { .. } => {}
267        }
268    }
269
270    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
271        let message_id = self.messages[ix];
272        let Some(message) = self.thread.read(cx).message(message_id) else {
273            return Empty.into_any();
274        };
275
276        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
277            return Empty.into_any();
278        };
279
280        let context = self.thread.read(cx).context_for_message(message_id);
281        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
282        let colors = cx.theme().colors();
283
284        let message_content = v_flex()
285            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
286            .when_some(context, |parent, context| {
287                if !context.is_empty() {
288                    parent.child(
289                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
290                            context
291                                .into_iter()
292                                .map(|context| ContextPill::added(context, false, false, None)),
293                        ),
294                    )
295                } else {
296                    parent
297                }
298            });
299
300        let styled_message = match message.role {
301            Role::User => v_flex()
302                .id(("message-container", ix))
303                .pt_2p5()
304                .px_2p5()
305                .child(
306                    v_flex()
307                        .bg(colors.editor_background)
308                        .rounded_lg()
309                        .border_1()
310                        .border_color(colors.border)
311                        .shadow_sm()
312                        .child(
313                            h_flex()
314                                .py_1()
315                                .px_2()
316                                .bg(colors.editor_foreground.opacity(0.05))
317                                .border_b_1()
318                                .border_color(colors.border)
319                                .justify_between()
320                                .rounded_t(px(6.))
321                                .child(
322                                    h_flex()
323                                        .gap_1p5()
324                                        .child(
325                                            Icon::new(IconName::PersonCircle)
326                                                .size(IconSize::XSmall)
327                                                .color(Color::Muted),
328                                        )
329                                        .child(
330                                            Label::new("You")
331                                                .size(LabelSize::Small)
332                                                .color(Color::Muted),
333                                        ),
334                                ),
335                        )
336                        .child(message_content),
337                ),
338            Role::Assistant => div()
339                .id(("message-container", ix))
340                .child(message_content)
341                .map(|parent| {
342                    if tool_uses.is_empty() {
343                        return parent;
344                    }
345
346                    parent.child(
347                        v_flex().children(
348                            tool_uses
349                                .into_iter()
350                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
351                        ),
352                    )
353                }),
354            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
355                v_flex()
356                    .bg(colors.editor_background)
357                    .rounded_md()
358                    .child(message_content),
359            ),
360        };
361
362        styled_message.into_any()
363    }
364
365    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
366        let is_open = self
367            .expanded_tool_uses
368            .get(&tool_use.id)
369            .copied()
370            .unwrap_or_default();
371
372        v_flex().px_2p5().child(
373            v_flex()
374                .gap_1()
375                .bg(cx.theme().colors().editor_background)
376                .rounded_lg()
377                .border_1()
378                .border_color(cx.theme().colors().border)
379                .shadow_sm()
380                .child(
381                    h_flex()
382                        .justify_between()
383                        .py_1()
384                        .px_2()
385                        .bg(cx.theme().colors().editor_foreground.opacity(0.05))
386                        .when(is_open, |element| element.border_b_1())
387                        .border_color(cx.theme().colors().border)
388                        .rounded_t(px(6.))
389                        .child(
390                            h_flex()
391                                .gap_2()
392                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
393                                    cx.listener({
394                                        let tool_use_id = tool_use.id.clone();
395                                        move |this, _event, _window, _cx| {
396                                            let is_open = this
397                                                .expanded_tool_uses
398                                                .entry(tool_use_id.clone())
399                                                .or_insert(false);
400
401                                            *is_open = !*is_open;
402                                        }
403                                    }),
404                                ))
405                                .child(Label::new(tool_use.name)),
406                        )
407                        .child(Label::new(match tool_use.status {
408                            ToolUseStatus::Pending => "Pending",
409                            ToolUseStatus::Running => "Running",
410                            ToolUseStatus::Finished(_) => "Finished",
411                            ToolUseStatus::Error(_) => "Error",
412                        })),
413                )
414                .map(|parent| {
415                    if !is_open {
416                        return parent;
417                    }
418
419                    parent.child(
420                        v_flex()
421                            .gap_2()
422                            .p_2p5()
423                            .child(
424                                v_flex()
425                                    .gap_0p5()
426                                    .child(Label::new("Input:"))
427                                    .child(Label::new(
428                                        serde_json::to_string_pretty(&tool_use.input)
429                                            .unwrap_or_default(),
430                                    )),
431                            )
432                            .map(|parent| match tool_use.status {
433                                ToolUseStatus::Finished(output) => parent.child(
434                                    v_flex()
435                                        .gap_0p5()
436                                        .child(Label::new("Result:"))
437                                        .child(Label::new(output)),
438                                ),
439                                ToolUseStatus::Error(err) => parent.child(
440                                    v_flex()
441                                        .gap_0p5()
442                                        .child(Label::new("Error:"))
443                                        .child(Label::new(err)),
444                                ),
445                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
446                            }),
447                    )
448                }),
449        )
450    }
451}
452
453impl Render for ActiveThread {
454    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
455        v_flex()
456            .size_full()
457            .child(list(self.list_state.clone()).flex_grow())
458    }
459}