active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
  7    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
  8    UnderlineStyle, WeakEntity,
  9};
 10use language::LanguageRegistry;
 11use language_model::Role;
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::ui::ContextPill;
 21
/// View state for a single assistant [`Thread`], displayed as a virtualized list of messages.
pub struct ActiveThread {
    /// Workspace handle passed to tools when they are run.
    workspace: WeakEntity<Workspace>,
    /// Registry handed to each message's [`Markdown`] (presumably for code-block language
    /// lookup — confirm against `Markdown::new`).
    language_registry: Arc<LanguageRegistry>,
    /// Tools used to resolve the thread's pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// Store used to save the thread after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Message ids in display order; list item `i` renders `messages[i]`.
    messages: Vec<MessageId>,
    /// State backing the virtualized message list.
    list_state: ListState,
    /// Markdown entity for each displayed message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// Most recent error reported by the thread, until it is taken or cleared.
    last_error: Option<ThreadError>,
    /// Holds the thread observer and event subscription created in `new` so they stay active
    /// for as long as this view exists.
    _subscriptions: Vec<Subscription>,
}
 34
 35impl ActiveThread {
 36    pub fn new(
 37        thread: Entity<Thread>,
 38        thread_store: Entity<ThreadStore>,
 39        workspace: WeakEntity<Workspace>,
 40        language_registry: Arc<LanguageRegistry>,
 41        tools: Arc<ToolWorkingSet>,
 42        window: &mut Window,
 43        cx: &mut Context<Self>,
 44    ) -> Self {
 45        let subscriptions = vec![
 46            cx.observe(&thread, |_, _, cx| cx.notify()),
 47            cx.subscribe_in(&thread, window, Self::handle_thread_event),
 48        ];
 49
 50        let mut this = Self {
 51            workspace,
 52            language_registry,
 53            tools,
 54            thread_store,
 55            thread: thread.clone(),
 56            messages: Vec::new(),
 57            rendered_messages_by_id: HashMap::default(),
 58            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 59                let this = cx.entity().downgrade();
 60                move |ix, _: &mut Window, cx: &mut App| {
 61                    this.update(cx, |this, cx| this.render_message(ix, cx))
 62                        .unwrap()
 63                }
 64            }),
 65            last_error: None,
 66            _subscriptions: subscriptions,
 67        };
 68
 69        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 70            this.push_message(&message.id, message.text.clone(), window, cx);
 71        }
 72
 73        this
 74    }
 75
 76    pub fn thread(&self) -> &Entity<Thread> {
 77        &self.thread
 78    }
 79
 80    pub fn is_empty(&self) -> bool {
 81        self.messages.is_empty()
 82    }
 83
 84    pub fn summary(&self, cx: &App) -> Option<SharedString> {
 85        self.thread.read(cx).summary()
 86    }
 87
 88    pub fn summary_or_default(&self, cx: &App) -> SharedString {
 89        self.thread.read(cx).summary_or_default()
 90    }
 91
 92    pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
 93        self.last_error.take();
 94        self.thread
 95            .update(cx, |thread, _cx| thread.cancel_last_completion())
 96    }
 97
 98    pub fn last_error(&self) -> Option<ThreadError> {
 99        self.last_error.clone()
100    }
101
102    pub fn clear_last_error(&mut self) {
103        self.last_error.take();
104    }
105
106    fn push_message(
107        &mut self,
108        id: &MessageId,
109        text: String,
110        window: &mut Window,
111        cx: &mut Context<Self>,
112    ) {
113        let old_len = self.messages.len();
114        self.messages.push(*id);
115        self.list_state.splice(old_len..old_len, 1);
116
117        let theme_settings = ThemeSettings::get_global(cx);
118        let colors = cx.theme().colors();
119        let ui_font_size = TextSize::Default.rems(cx);
120        let buffer_font_size = TextSize::Small.rems(cx);
121        let mut text_style = window.text_style();
122
123        text_style.refine(&TextStyleRefinement {
124            font_family: Some(theme_settings.ui_font.family.clone()),
125            font_size: Some(ui_font_size.into()),
126            color: Some(cx.theme().colors().text),
127            ..Default::default()
128        });
129
130        let markdown_style = MarkdownStyle {
131            base_text_style: text_style,
132            syntax: cx.theme().syntax().clone(),
133            selection_background_color: cx.theme().players().local().selection,
134            code_block: StyleRefinement {
135                margin: EdgesRefinement {
136                    top: Some(Length::Definite(rems(0.).into())),
137                    left: Some(Length::Definite(rems(0.).into())),
138                    right: Some(Length::Definite(rems(0.).into())),
139                    bottom: Some(Length::Definite(rems(0.5).into())),
140                },
141                padding: EdgesRefinement {
142                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
143                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
144                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
145                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146                },
147                background: Some(colors.editor_background.into()),
148                border_color: Some(colors.border_variant),
149                border_widths: EdgesRefinement {
150                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
151                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
152                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
153                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
154                },
155                text: Some(TextStyleRefinement {
156                    font_family: Some(theme_settings.buffer_font.family.clone()),
157                    font_size: Some(buffer_font_size.into()),
158                    ..Default::default()
159                }),
160                ..Default::default()
161            },
162            inline_code: TextStyleRefinement {
163                font_family: Some(theme_settings.buffer_font.family.clone()),
164                font_size: Some(buffer_font_size.into()),
165                background_color: Some(colors.editor_foreground.opacity(0.1)),
166                ..Default::default()
167            },
168            link: TextStyleRefinement {
169                background_color: Some(colors.editor_foreground.opacity(0.025)),
170                underline: Some(UnderlineStyle {
171                    color: Some(colors.text_accent.opacity(0.5)),
172                    thickness: px(1.),
173                    ..Default::default()
174                }),
175                ..Default::default()
176            },
177            ..Default::default()
178        };
179
180        let markdown = cx.new(|cx| {
181            Markdown::new(
182                text,
183                markdown_style,
184                Some(self.language_registry.clone()),
185                None,
186                window,
187                cx,
188            )
189        });
190        self.rendered_messages_by_id.insert(*id, markdown);
191        self.list_state.scroll_to(ListOffset {
192            item_ix: old_len,
193            offset_in_item: Pixels(0.0),
194        });
195    }
196
197    fn handle_thread_event(
198        &mut self,
199        _: &Entity<Thread>,
200        event: &ThreadEvent,
201        window: &mut Window,
202        cx: &mut Context<Self>,
203    ) {
204        match event {
205            ThreadEvent::ShowError(error) => {
206                self.last_error = Some(error.clone());
207            }
208            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
209                self.thread_store
210                    .update(cx, |thread_store, cx| {
211                        thread_store.save_thread(&self.thread, cx)
212                    })
213                    .detach_and_log_err(cx);
214            }
215            ThreadEvent::StreamedAssistantText(message_id, text) => {
216                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
217                    markdown.update(cx, |markdown, cx| {
218                        markdown.append(text, window, cx);
219                    });
220                }
221            }
222            ThreadEvent::MessageAdded(message_id) => {
223                if let Some(message_text) = self
224                    .thread
225                    .read(cx)
226                    .message(*message_id)
227                    .map(|message| message.text.clone())
228                {
229                    self.push_message(message_id, message_text, window, cx);
230                }
231
232                self.thread_store
233                    .update(cx, |thread_store, cx| {
234                        thread_store.save_thread(&self.thread, cx)
235                    })
236                    .detach_and_log_err(cx);
237
238                cx.notify();
239            }
240            ThreadEvent::UsePendingTools => {
241                let pending_tool_uses = self
242                    .thread
243                    .read(cx)
244                    .pending_tool_uses()
245                    .into_iter()
246                    .filter(|tool_use| tool_use.status.is_idle())
247                    .cloned()
248                    .collect::<Vec<_>>();
249
250                for tool_use in pending_tool_uses {
251                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
252                        let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
253
254                        self.thread.update(cx, |thread, cx| {
255                            thread.insert_tool_output(
256                                tool_use.assistant_message_id,
257                                tool_use.id.clone(),
258                                task,
259                                cx,
260                            );
261                        });
262                    }
263                }
264            }
265            ThreadEvent::ToolFinished { .. } => {}
266        }
267    }
268
269    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
270        let message_id = self.messages[ix];
271        let Some(message) = self.thread.read(cx).message(message_id) else {
272            return Empty.into_any();
273        };
274
275        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
276            return Empty.into_any();
277        };
278
279        let context = self.thread.read(cx).context_for_message(message_id);
280        let colors = cx.theme().colors();
281
282        let message_content = v_flex()
283            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
284            .when_some(context, |parent, context| {
285                if !context.is_empty() {
286                    parent.child(
287                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
288                            context
289                                .into_iter()
290                                .map(|context| ContextPill::added(context, false, false, None)),
291                        ),
292                    )
293                } else {
294                    parent
295                }
296            });
297
298        let styled_message = match message.role {
299            Role::User => v_flex()
300                .id(("message-container", ix))
301                .pt_2p5()
302                .px_2p5()
303                .child(
304                    v_flex()
305                        .bg(colors.editor_background)
306                        .rounded_lg()
307                        .border_1()
308                        .border_color(colors.border)
309                        .shadow_sm()
310                        .child(
311                            h_flex()
312                                .py_1()
313                                .px_2()
314                                .bg(colors.editor_foreground.opacity(0.05))
315                                .border_b_1()
316                                .border_color(colors.border)
317                                .justify_between()
318                                .rounded_t(px(6.))
319                                .child(
320                                    h_flex()
321                                        .gap_1p5()
322                                        .child(
323                                            Icon::new(IconName::PersonCircle)
324                                                .size(IconSize::XSmall)
325                                                .color(Color::Muted),
326                                        )
327                                        .child(
328                                            Label::new("You")
329                                                .size(LabelSize::Small)
330                                                .color(Color::Muted),
331                                        ),
332                                ),
333                        )
334                        .child(message_content),
335                ),
336            Role::Assistant => div().id(("message-container", ix)).child(message_content),
337            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
338                v_flex()
339                    .bg(colors.editor_background)
340                    .rounded_md()
341                    .child(message_content),
342            ),
343        };
344
345        styled_message.into_any()
346    }
347}
348
349impl Render for ActiveThread {
350    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
351        v_flex()
352            .size_full()
353            .child(list(self.list_state.clone()).flex_grow())
354    }
355}