active_thread.rs

  1use std::sync::Arc;
  2use std::time::Duration;
  3
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use gpui::{
  7    linear_color_stop, linear_gradient, list, percentage, AbsoluteLength, Animation, AnimationExt,
  8    AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, FocusHandle, Length,
  9    ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
 10    TextStyleRefinement, Transformation, UnderlineStyle, View, WeakView,
 11};
 12use language::LanguageRegistry;
 13use language_model::Role;
 14use markdown::{Markdown, MarkdownStyle};
 15use settings::Settings as _;
 16use theme::ThemeSettings;
 17use ui::{prelude::*, Divider, KeyBinding};
 18use workspace::Workspace;
 19
 20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 21use crate::ui::ContextPill;
 22
/// View that renders the messages of a [`Thread`] as a scrollable list and
/// runs the thread's idle pending tool uses when asked to.
pub struct ActiveThread {
    /// Workspace handle passed to each tool when it is run.
    workspace: WeakView<Workspace>,
    /// Registry handed to every message's `Markdown` view.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread requests pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// The thread model whose messages this view displays.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order; index `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// State of the virtualized message list (item count, scroll position).
    list_state: ListState,
    /// One `Markdown` view per displayed message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported via `ThreadEvent::ShowError`, if not yet cleared.
    last_error: Option<ThreadError>,
    /// Focus handle used to resolve and dispatch the Cancel action.
    focus_handle: FocusHandle,
    /// Keeps the observation of / subscription to `thread` alive.
    _subscriptions: Vec<Subscription>,
}
 35
 36impl ActiveThread {
 37    pub fn new(
 38        thread: Model<Thread>,
 39        workspace: WeakView<Workspace>,
 40        language_registry: Arc<LanguageRegistry>,
 41        tools: Arc<ToolWorkingSet>,
 42        focus_handle: FocusHandle,
 43        cx: &mut ViewContext<Self>,
 44    ) -> Self {
 45        let subscriptions = vec![
 46            cx.observe(&thread, |_, _, cx| cx.notify()),
 47            cx.subscribe(&thread, Self::handle_thread_event),
 48        ];
 49
 50        let mut this = Self {
 51            workspace,
 52            language_registry,
 53            tools,
 54            thread: thread.clone(),
 55            messages: Vec::new(),
 56            rendered_messages_by_id: HashMap::default(),
 57            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 58                let this = cx.view().downgrade();
 59                move |ix, cx: &mut WindowContext| {
 60                    this.update(cx, |this, cx| this.render_message(ix, cx))
 61                        .unwrap()
 62                }
 63            }),
 64            last_error: None,
 65            focus_handle,
 66            _subscriptions: subscriptions,
 67        };
 68
 69        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 70            this.push_message(&message.id, message.text.clone(), cx);
 71        }
 72
 73        this
 74    }
 75
 76    pub fn is_empty(&self) -> bool {
 77        self.messages.is_empty()
 78    }
 79
 80    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 81        self.thread.read(cx).summary()
 82    }
 83
 84    pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
 85        self.thread.read(cx).summary_or_default()
 86    }
 87
 88    pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
 89        self.last_error.take();
 90        self.thread
 91            .update(cx, |thread, _cx| thread.cancel_last_completion())
 92    }
 93
 94    pub fn last_error(&self) -> Option<ThreadError> {
 95        self.last_error.clone()
 96    }
 97
 98    pub fn clear_last_error(&mut self) {
 99        self.last_error.take();
100    }
101
102    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
103        let old_len = self.messages.len();
104        self.messages.push(*id);
105        self.list_state.splice(old_len..old_len, 1);
106
107        let theme_settings = ThemeSettings::get_global(cx);
108        let colors = cx.theme().colors();
109        let ui_font_size = TextSize::Default.rems(cx);
110        let buffer_font_size = TextSize::Small.rems(cx);
111        let mut text_style = cx.text_style();
112
113        text_style.refine(&TextStyleRefinement {
114            font_family: Some(theme_settings.ui_font.family.clone()),
115            font_size: Some(ui_font_size.into()),
116            color: Some(cx.theme().colors().text),
117            ..Default::default()
118        });
119
120        let markdown_style = MarkdownStyle {
121            base_text_style: text_style,
122            syntax: cx.theme().syntax().clone(),
123            selection_background_color: cx.theme().players().local().selection,
124            code_block: StyleRefinement {
125                margin: EdgesRefinement {
126                    top: Some(Length::Definite(rems(0.).into())),
127                    left: Some(Length::Definite(rems(0.).into())),
128                    right: Some(Length::Definite(rems(0.).into())),
129                    bottom: Some(Length::Definite(rems(0.5).into())),
130                },
131                padding: EdgesRefinement {
132                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
134                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
135                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
136                },
137                background: Some(colors.editor_background.into()),
138                border_color: Some(colors.border_variant),
139                border_widths: EdgesRefinement {
140                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
141                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
142                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
143                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
144                },
145                text: Some(TextStyleRefinement {
146                    font_family: Some(theme_settings.buffer_font.family.clone()),
147                    font_size: Some(buffer_font_size.into()),
148                    ..Default::default()
149                }),
150                ..Default::default()
151            },
152            inline_code: TextStyleRefinement {
153                font_family: Some(theme_settings.buffer_font.family.clone()),
154                font_size: Some(buffer_font_size.into()),
155                background_color: Some(colors.editor_foreground.opacity(0.1)),
156                ..Default::default()
157            },
158            link: TextStyleRefinement {
159                background_color: Some(colors.editor_foreground.opacity(0.025)),
160                underline: Some(UnderlineStyle {
161                    color: Some(colors.text_accent.opacity(0.5)),
162                    thickness: px(1.),
163                    ..Default::default()
164                }),
165                ..Default::default()
166            },
167            ..Default::default()
168        };
169
170        let markdown = cx.new_view(|cx| {
171            Markdown::new(
172                text,
173                markdown_style,
174                Some(self.language_registry.clone()),
175                None,
176                cx,
177            )
178        });
179        self.rendered_messages_by_id.insert(*id, markdown);
180        self.list_state.scroll_to(ListOffset {
181            item_ix: old_len,
182            offset_in_item: Pixels(0.0),
183        });
184    }
185
186    fn handle_thread_event(
187        &mut self,
188        _: Model<Thread>,
189        event: &ThreadEvent,
190        cx: &mut ViewContext<Self>,
191    ) {
192        match event {
193            ThreadEvent::ShowError(error) => {
194                self.last_error = Some(error.clone());
195            }
196            ThreadEvent::StreamedCompletion => {}
197            ThreadEvent::SummaryChanged => {}
198            ThreadEvent::StreamedAssistantText(message_id, text) => {
199                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
200                    markdown.update(cx, |markdown, cx| {
201                        markdown.append(text, cx);
202                    });
203                }
204            }
205            ThreadEvent::MessageAdded(message_id) => {
206                if let Some(message_text) = self
207                    .thread
208                    .read(cx)
209                    .message(*message_id)
210                    .map(|message| message.text.clone())
211                {
212                    self.push_message(message_id, message_text, cx);
213                }
214
215                cx.notify();
216            }
217            ThreadEvent::UsePendingTools => {
218                let pending_tool_uses = self
219                    .thread
220                    .read(cx)
221                    .pending_tool_uses()
222                    .into_iter()
223                    .filter(|tool_use| tool_use.status.is_idle())
224                    .cloned()
225                    .collect::<Vec<_>>();
226
227                for tool_use in pending_tool_uses {
228                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
229                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
230
231                        self.thread.update(cx, |thread, cx| {
232                            thread.insert_tool_output(
233                                tool_use.assistant_message_id,
234                                tool_use.id.clone(),
235                                task,
236                                cx,
237                            );
238                        });
239                    }
240                }
241            }
242            ThreadEvent::ToolFinished { .. } => {}
243        }
244    }
245
246    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
247        let message_id = self.messages[ix];
248        let Some(message) = self.thread.read(cx).message(message_id) else {
249            return Empty.into_any();
250        };
251
252        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
253            return Empty.into_any();
254        };
255
256        let context = self.thread.read(cx).context_for_message(message_id);
257        let colors = cx.theme().colors();
258
259        let message_content = v_flex()
260            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
261            .when_some(context, |parent, context| {
262                if !context.is_empty() {
263                    parent.child(
264                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
265                            context
266                                .into_iter()
267                                .map(|context| ContextPill::new_added(context, false, false, None)),
268                        ),
269                    )
270                } else {
271                    parent
272                }
273            });
274
275        let styled_message = match message.role {
276            Role::User => v_flex()
277                .id(("message-container", ix))
278                .py_1()
279                .px_2p5()
280                .child(
281                    v_flex()
282                        .bg(colors.editor_background)
283                        .ml_16()
284                        .rounded_t_lg()
285                        .rounded_bl_lg()
286                        .rounded_br_none()
287                        .border_1()
288                        .border_color(colors.border)
289                        .child(
290                            h_flex()
291                                .py_1()
292                                .px_2()
293                                .bg(colors.editor_foreground.opacity(0.05))
294                                .border_b_1()
295                                .border_color(colors.border)
296                                .justify_between()
297                                .rounded_t(px(6.))
298                                .child(
299                                    h_flex()
300                                        .gap_1p5()
301                                        .child(
302                                            Icon::new(IconName::PersonCircle)
303                                                .size(IconSize::XSmall)
304                                                .color(Color::Muted),
305                                        )
306                                        .child(
307                                            Label::new("You")
308                                                .size(LabelSize::Small)
309                                                .color(Color::Muted),
310                                        ),
311                                ),
312                        )
313                        .child(message_content),
314                ),
315            Role::Assistant => div().id(("message-container", ix)).child(message_content),
316            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
317                v_flex()
318                    .bg(colors.editor_background)
319                    .rounded_md()
320                    .child(message_content),
321            ),
322        };
323
324        styled_message.into_any()
325    }
326}
327
328impl Render for ActiveThread {
329    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
330        let is_streaming_completion = self.thread.read(cx).is_streaming();
331        let panel_bg = cx.theme().colors().panel_background;
332        let focus_handle = self.focus_handle.clone();
333
334        v_flex()
335            .size_full()
336            .pt_1p5()
337            .child(list(self.list_state.clone()).flex_grow())
338            .when(is_streaming_completion, |parent| {
339                parent.child(
340                    h_flex()
341                        .w_full()
342                        .pb_2p5()
343                        .absolute()
344                        .bottom_0()
345                        .flex_shrink()
346                        .justify_center()
347                        .bg(linear_gradient(
348                            180.,
349                            linear_color_stop(panel_bg.opacity(0.0), 0.),
350                            linear_color_stop(panel_bg, 1.),
351                        ))
352                        .child(
353                            h_flex()
354                                .flex_none()
355                                .p_1p5()
356                                .bg(cx.theme().colors().editor_background)
357                                .border_1()
358                                .border_color(cx.theme().colors().border)
359                                .rounded_md()
360                                .shadow_lg()
361                                .gap_1()
362                                .child(
363                                    Icon::new(IconName::ArrowCircle)
364                                        .size(IconSize::Small)
365                                        .color(Color::Muted)
366                                        .with_animation(
367                                            "arrow-circle",
368                                            Animation::new(Duration::from_secs(2)).repeat(),
369                                            |icon, delta| {
370                                                icon.transform(Transformation::rotate(percentage(
371                                                    delta,
372                                                )))
373                                            },
374                                        ),
375                                )
376                                .child(
377                                    Label::new("Generating…")
378                                        .size(LabelSize::Small)
379                                        .color(Color::Muted),
380                                )
381                                .child(Divider::vertical())
382                                .child(
383                                    Button::new("cancel-generation", "Cancel")
384                                        .label_size(LabelSize::Small)
385                                        .key_binding(KeyBinding::for_action_in(
386                                            &editor::actions::Cancel,
387                                            &self.focus_handle,
388                                            cx,
389                                        ))
390                                        .on_click(move |_event, cx| {
391                                            focus_handle
392                                                .dispatch_action(&editor::actions::Cancel, cx);
393                                        }),
394                                ),
395                        ),
396                )
397            })
398    }
399}