active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::Role;
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::ui::ContextPill;
 21
 22pub struct ActiveThread {
 23    workspace: WeakEntity<Workspace>,
 24    language_registry: Arc<LanguageRegistry>,
 25    tools: Arc<ToolWorkingSet>,
 26    thread_store: Entity<ThreadStore>,
 27    thread: Entity<Thread>,
 28    messages: Vec<MessageId>,
 29    list_state: ListState,
 30    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
 31    last_error: Option<ThreadError>,
 32    _subscriptions: Vec<Subscription>,
 33}
 34
 35impl ActiveThread {
 36    pub fn new(
 37        thread: Entity<Thread>,
 38        thread_store: Entity<ThreadStore>,
 39        workspace: WeakEntity<Workspace>,
 40        language_registry: Arc<LanguageRegistry>,
 41        tools: Arc<ToolWorkingSet>,
 42        window: &mut Window,
 43        cx: &mut Context<Self>,
 44    ) -> Self {
 45        let subscriptions = vec![
 46            cx.observe(&thread, |_, _, cx| cx.notify()),
 47            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 48        ];
 49
 50        let mut this = Self {
 51            workspace,
 52            language_registry,
 53            tools,
 54            thread_store,
 55            thread: thread.clone(),
 56            messages: Vec::new(),
 57            rendered_messages_by_id: HashMap::default(),
 58            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 59                let this = cx.entity().downgrade();
 60                move |ix, _: &mut Window, cx: &mut App| {
 61                    this.update(cx, |this, cx| this.render_message(ix, cx))
 62                        .unwrap()
 63                }
 64            }),
 65            last_error: None,
 66            _subscriptions: subscriptions,
 67        };
 68
 69        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 70            this.push_message(&message.id, message.text.clone(), window, cx);
 71        }
 72
 73        this
 74    }
 75
 76    pub fn thread(&self) -> &Entity<Thread> {
 77        &self.thread
 78    }
 79
 80    pub fn is_empty(&self) -> bool {
 81        self.messages.is_empty()
 82    }
 83
 84    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 85        self.thread.read(cx).summary()
 86    }
 87
 88    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 89        self.thread.read(cx).summary_or_default()
 90    }
 91
 92    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 93        self.last_error.take();
 94        self.thread
 95            .update(cx, |thread, _cx| thread.cancel_last_completion())
 96    }
 97
 98    pub fn last_error(&self) -> Option<ThreadError> {
 99        self.last_error.clone()
100    }
101
102    pub fn clear_last_error(&mut self) {
103        self.last_error.take();
104    }
105
106    fn push_message(
107        &mut self,
108        id: &MessageId,
109        text: String,
110        window: &mut Window,
111        cx: &mut Context<Self>,
112    ) {
113        let old_len = self.messages.len();
114        self.messages.push(*id);
115        self.list_state.splice(old_len..old_len, 1);
116
117        let theme_settings = ThemeSettings::get_global(cx);
118        let colors = cx.theme().colors();
119        let ui_font_size = TextSize::Default.rems(cx);
120        let buffer_font_size = TextSize::Small.rems(cx);
121        let mut text_style = window.text_style();
122
123        text_style.refine(&TextStyleRefinement {
124            font_family: Some(theme_settings.ui_font.family.clone()),
125            font_size: Some(ui_font_size.into()),
126            color: Some(cx.theme().colors().text),
127            ..Default::default()
128        });
129
130        let markdown_style = MarkdownStyle {
131            base_text_style: text_style,
132            syntax: cx.theme().syntax().clone(),
133            selection_background_color: cx.theme().players().local().selection,
134            code_block: StyleRefinement {
135                margin: EdgesRefinement {
136                    top: Some(Length::Definite(rems(0.).into())),
137                    left: Some(Length::Definite(rems(0.).into())),
138                    right: Some(Length::Definite(rems(0.).into())),
139                    bottom: Some(Length::Definite(rems(0.5).into())),
140                },
141                padding: EdgesRefinement {
142                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
143                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
144                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
145                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146                },
147                background: Some(colors.editor_background.into()),
148                border_color: Some(colors.border_variant),
149                border_widths: EdgesRefinement {
150                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
151                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
152                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
153                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
154                },
155                text: Some(TextStyleRefinement {
156                    font_family: Some(theme_settings.buffer_font.family.clone()),
157                    font_size: Some(buffer_font_size.into()),
158                    ..Default::default()
159                }),
160                ..Default::default()
161            },
162            inline_code: TextStyleRefinement {
163                font_family: Some(theme_settings.buffer_font.family.clone()),
164                font_size: Some(buffer_font_size.into()),
165                background_color: Some(colors.editor_foreground.opacity(0.1)),
166                ..Default::default()
167            },
168            link: TextStyleRefinement {
169                background_color: Some(colors.editor_foreground.opacity(0.025)),
170                underline: Some(UnderlineStyle {
171                    color: Some(colors.text_accent.opacity(0.5)),
172                    thickness: px(1.),
173                    ..Default::default()
174                }),
175                ..Default::default()
176            },
177            ..Default::default()
178        };
179
180        let markdown = cx.new(|cx| {
181            Markdown::new(
182                text.into(),
183                markdown_style,
184                Some(self.language_registry.clone()),
185                None,
186                cx,
187            )
188        });
189        self.rendered_messages_by_id.insert(*id, markdown);
190        self.list_state.scroll_to(ListOffset {
191            item_ix: old_len,
192            offset_in_item: Pixels(0.0),
193        });
194    }
195
196    fn handle_thread_event(
197        &mut self,
198        _: &Entity<Thread>,
199        event: &ThreadEvent,
200        window: &mut Window,
201        cx: &mut Context<Self>,
202    ) {
203        match event {
204            ThreadEvent::ShowError(error) => {
205                self.last_error = Some(error.clone());
206            }
207            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
208                self.thread_store
209                    .update(cx, |thread_store, cx| {
210                        thread_store.save_thread(&self.thread, cx)
211                    })
212                    .detach_and_log_err(cx);
213            }
214            ThreadEvent::StreamedAssistantText(message_id, text) => {
215                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
216                    markdown.update(cx, |markdown, cx| {
217                        markdown.append(text, cx);
218                    });
219                }
220            }
221            ThreadEvent::MessageAdded(message_id) => {
222                if let Some(message_text) = self
223                    .thread
224                    .read(cx)
225                    .message(*message_id)
226                    .map(|message| message.text.clone())
227                {
228                    self.push_message(message_id, message_text, window, cx);
229                }
230
231                self.thread_store
232                    .update(cx, |thread_store, cx| {
233                        thread_store.save_thread(&self.thread, cx)
234                    })
235                    .detach_and_log_err(cx);
236
237                cx.notify();
238            }
239            ThreadEvent::UsePendingTools => {
240                let pending_tool_uses = self
241                    .thread
242                    .read(cx)
243                    .pending_tool_uses()
244                    .into_iter()
245                    .filter(|tool_use| tool_use.status.is_idle())
246                    .cloned()
247                    .collect::<Vec<_>>();
248
249                for tool_use in pending_tool_uses {
250                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
251                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
252
253                        self.thread.update(cx, |thread, cx| {
254                            thread.insert_tool_output(
255                                tool_use.assistant_message_id,
256                                tool_use.id.clone(),
257                                task,
258                                cx,
259                            );
260                        });
261                    }
262                }
263            }
264            ThreadEvent::ToolFinished { .. } => {}
265        }
266    }
267
268    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
269        let message_id = self.messages[ix];
270        let Some(message) = self.thread.read(cx).message(message_id) else {
271            return Empty.into_any();
272        };
273
274        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
275            return Empty.into_any();
276        };
277
278        let context = self.thread.read(cx).context_for_message(message_id);
279        let colors = cx.theme().colors();
280
281        let message_content = v_flex()
282            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
283            .when_some(context, |parent, context| {
284                if !context.is_empty() {
285                    parent.child(
286                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
287                            context
288                                .into_iter()
289                                .map(|context| ContextPill::added(context, false, false, None)),
290                        ),
291                    )
292                } else {
293                    parent
294                }
295            });
296
297        let styled_message = match message.role {
298            Role::User => v_flex()
299                .id(("message-container", ix))
300                .pt_2p5()
301                .px_2p5()
302                .child(
303                    v_flex()
304                        .bg(colors.editor_background)
305                        .rounded_lg()
306                        .border_1()
307                        .border_color(colors.border)
308                        .shadow_sm()
309                        .child(
310                            h_flex()
311                                .py_1()
312                                .px_2()
313                                .bg(colors.editor_foreground.opacity(0.05))
314                                .border_b_1()
315                                .border_color(colors.border)
316                                .justify_between()
317                                .rounded_t(px(6.))
318                                .child(
319                                    h_flex()
320                                        .gap_1p5()
321                                        .child(
322                                            Icon::new(IconName::PersonCircle)
323                                                .size(IconSize::XSmall)
324                                                .color(Color::Muted),
325                                        )
326                                        .child(
327                                            Label::new("You")
328                                                .size(LabelSize::Small)
329                                                .color(Color::Muted),
330                                        ),
331                                ),
332                        )
333                        .child(message_content),
334                ),
335            Role::Assistant => div().id(("message-container", ix)).child(message_content),
336            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
337                v_flex()
338                    .bg(colors.editor_background)
339                    .rounded_md()
340                    .child(message_content),
341            ),
342        };
343
344        styled_message.into_any()
345    }
346}
347
348impl Render for ActiveThread {
349    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
350        v_flex()
351            .size_full()
352            .child(list(self.list_state.clone()).flex_grow())
353    }
354}