active_thread.rs

  1use std::sync::Arc;
  2use std::time::Duration;
  3
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use gpui::{
  7    list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, AppContext,
  8    DefiniteLength, EdgesRefinement, Empty, FocusHandle, Length, ListAlignment, ListOffset,
  9    ListState, Model, StyleRefinement, Subscription, TextStyleRefinement, Transformation,
 10    UnderlineStyle, View, WeakView,
 11};
 12use language::LanguageRegistry;
 13use language_model::Role;
 14use markdown::{Markdown, MarkdownStyle};
 15use settings::Settings as _;
 16use theme::ThemeSettings;
 17use ui::{prelude::*, ButtonLike, KeyBinding};
 18use workspace::Workspace;
 19
 20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 21use crate::ui::ContextPill;
 22
/// View that renders the messages of a [`Thread`] as a scrolling list, keeps
/// their Markdown in sync with streamed assistant text, and runs the tool uses
/// the thread asks for.
pub struct ActiveThread {
    /// Workspace handle passed to each tool when it runs.
    workspace: WeakView<Workspace>,
    /// Passed to every message's [`Markdown`] view on creation.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread emits `UsePendingTools`.
    tools: Arc<ToolWorkingSet>,
    /// The thread being displayed.
    pub(crate) thread: Model<Thread>,
    /// Ids of the displayed messages in list order; list item `i` is `messages[i]`.
    messages: Vec<MessageId>,
    /// List element state; its item count is grown in step with `messages`.
    list_state: ListState,
    /// Markdown view for each displayed message, appended to while streaming.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported by the thread, until taken or cleared.
    last_error: Option<ThreadError>,
    /// Focus handle the cancel action is dispatched on and whose keybinding is shown.
    focus_handle: FocusHandle,
    /// Keeps the thread observation and event subscription alive for the view's lifetime.
    _subscriptions: Vec<Subscription>,
}
 35
 36impl ActiveThread {
 37    pub fn new(
 38        thread: Model<Thread>,
 39        workspace: WeakView<Workspace>,
 40        language_registry: Arc<LanguageRegistry>,
 41        tools: Arc<ToolWorkingSet>,
 42        focus_handle: FocusHandle,
 43        cx: &mut ViewContext<Self>,
 44    ) -> Self {
 45        let subscriptions = vec![
 46            cx.observe(&thread, |_, _, cx| cx.notify()),
 47            cx.subscribe(&thread, Self::handle_thread_event),
 48        ];
 49
 50        let mut this = Self {
 51            workspace,
 52            language_registry,
 53            tools,
 54            thread: thread.clone(),
 55            messages: Vec::new(),
 56            rendered_messages_by_id: HashMap::default(),
 57            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 58                let this = cx.view().downgrade();
 59                move |ix, cx: &mut WindowContext| {
 60                    this.update(cx, |this, cx| this.render_message(ix, cx))
 61                        .unwrap()
 62                }
 63            }),
 64            last_error: None,
 65            focus_handle,
 66            _subscriptions: subscriptions,
 67        };
 68
 69        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 70            this.push_message(&message.id, message.text.clone(), cx);
 71        }
 72
 73        this
 74    }
 75
 76    pub fn is_empty(&self) -> bool {
 77        self.messages.is_empty()
 78    }
 79
 80    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 81        self.thread.read(cx).summary()
 82    }
 83
 84    pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
 85        self.thread.read(cx).summary_or_default()
 86    }
 87
 88    pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
 89        self.last_error.take();
 90        self.thread
 91            .update(cx, |thread, _cx| thread.cancel_last_completion())
 92    }
 93
 94    pub fn last_error(&self) -> Option<ThreadError> {
 95        self.last_error.clone()
 96    }
 97
 98    pub fn clear_last_error(&mut self) {
 99        self.last_error.take();
100    }
101
102    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
103        let old_len = self.messages.len();
104        self.messages.push(*id);
105        self.list_state.splice(old_len..old_len, 1);
106
107        let theme_settings = ThemeSettings::get_global(cx);
108        let colors = cx.theme().colors();
109        let ui_font_size = TextSize::Default.rems(cx);
110        let buffer_font_size = TextSize::Small.rems(cx);
111        let mut text_style = cx.text_style();
112
113        text_style.refine(&TextStyleRefinement {
114            font_family: Some(theme_settings.ui_font.family.clone()),
115            font_size: Some(ui_font_size.into()),
116            color: Some(cx.theme().colors().text),
117            ..Default::default()
118        });
119
120        let markdown_style = MarkdownStyle {
121            base_text_style: text_style,
122            syntax: cx.theme().syntax().clone(),
123            selection_background_color: cx.theme().players().local().selection,
124            code_block: StyleRefinement {
125                margin: EdgesRefinement {
126                    top: Some(Length::Definite(rems(1.0).into())),
127                    left: Some(Length::Definite(rems(0.).into())),
128                    right: Some(Length::Definite(rems(0.).into())),
129                    bottom: Some(Length::Definite(rems(1.).into())),
130                },
131                padding: EdgesRefinement {
132                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
134                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
135                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
136                },
137                background: Some(colors.editor_foreground.opacity(0.01).into()),
138                border_color: Some(colors.border_variant.opacity(0.3)),
139                border_widths: EdgesRefinement {
140                    top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
141                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
142                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
143                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
144                },
145                text: Some(TextStyleRefinement {
146                    font_family: Some(theme_settings.buffer_font.family.clone()),
147                    font_size: Some(buffer_font_size.into()),
148                    ..Default::default()
149                }),
150                ..Default::default()
151            },
152            inline_code: TextStyleRefinement {
153                font_family: Some(theme_settings.buffer_font.family.clone()),
154                font_size: Some(buffer_font_size.into()),
155                background_color: Some(colors.editor_foreground.opacity(0.1)),
156                ..Default::default()
157            },
158            link: TextStyleRefinement {
159                background_color: Some(colors.editor_foreground.opacity(0.025)),
160                underline: Some(UnderlineStyle {
161                    color: Some(colors.text_accent.opacity(0.5)),
162                    thickness: px(1.),
163                    ..Default::default()
164                }),
165                ..Default::default()
166            },
167            ..Default::default()
168        };
169
170        let markdown = cx.new_view(|cx| {
171            Markdown::new(
172                text,
173                markdown_style,
174                Some(self.language_registry.clone()),
175                None,
176                cx,
177            )
178        });
179        self.rendered_messages_by_id.insert(*id, markdown);
180        self.list_state.scroll_to(ListOffset {
181            item_ix: old_len,
182            offset_in_item: Pixels(0.0),
183        });
184    }
185
186    fn handle_thread_event(
187        &mut self,
188        _: Model<Thread>,
189        event: &ThreadEvent,
190        cx: &mut ViewContext<Self>,
191    ) {
192        match event {
193            ThreadEvent::ShowError(error) => {
194                self.last_error = Some(error.clone());
195            }
196            ThreadEvent::StreamedCompletion => {}
197            ThreadEvent::SummaryChanged => {}
198            ThreadEvent::StreamedAssistantText(message_id, text) => {
199                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
200                    markdown.update(cx, |markdown, cx| {
201                        markdown.append(text, cx);
202                    });
203                }
204            }
205            ThreadEvent::MessageAdded(message_id) => {
206                if let Some(message_text) = self
207                    .thread
208                    .read(cx)
209                    .message(*message_id)
210                    .map(|message| message.text.clone())
211                {
212                    self.push_message(message_id, message_text, cx);
213                }
214
215                cx.notify();
216            }
217            ThreadEvent::UsePendingTools => {
218                let pending_tool_uses = self
219                    .thread
220                    .read(cx)
221                    .pending_tool_uses()
222                    .into_iter()
223                    .filter(|tool_use| tool_use.status.is_idle())
224                    .cloned()
225                    .collect::<Vec<_>>();
226
227                for tool_use in pending_tool_uses {
228                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
229                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
230
231                        self.thread.update(cx, |thread, cx| {
232                            thread.insert_tool_output(
233                                tool_use.assistant_message_id,
234                                tool_use.id.clone(),
235                                task,
236                                cx,
237                            );
238                        });
239                    }
240                }
241            }
242            ThreadEvent::ToolFinished { .. } => {}
243        }
244    }
245
246    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
247        let message_id = self.messages[ix];
248        let is_last_message = ix == self.messages.len() - 1;
249        let Some(message) = self.thread.read(cx).message(message_id) else {
250            return Empty.into_any();
251        };
252
253        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
254            return Empty.into_any();
255        };
256
257        let is_streaming_completion = self.thread.read(cx).is_streaming();
258        let context = self.thread.read(cx).context_for_message(message_id);
259        let colors = cx.theme().colors();
260
261        let (role_icon, role_name, role_color) = match message.role {
262            Role::User => (IconName::Person, "You", Color::Muted),
263            Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
264            Role::System => (IconName::Settings, "System", Color::Default),
265        };
266
267        div()
268            .id(("message-container", ix))
269            .py_1()
270            .px_2()
271            .child(
272                v_flex()
273                    .border_1()
274                    .border_color(colors.border_variant)
275                    .bg(colors.editor_background)
276                    .rounded_md()
277                    .child(
278                        h_flex()
279                            .py_1p5()
280                            .px_2p5()
281                            .border_b_1()
282                            .border_color(colors.border_variant)
283                            .justify_between()
284                            .child(
285                                h_flex()
286                                    .gap_1p5()
287                                    .child(
288                                        Icon::new(role_icon)
289                                            .size(IconSize::XSmall)
290                                            .color(role_color),
291                                    )
292                                    .child(
293                                        Label::new(role_name)
294                                            .size(LabelSize::XSmall)
295                                            .color(role_color),
296                                    ),
297                            ),
298                    )
299                    .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
300                    .when(
301                        message.role == Role::Assistant
302                            && is_last_message
303                            && is_streaming_completion,
304                        |parent| {
305                            parent.child(
306                                h_flex()
307                                    .gap_1()
308                                    .p_2p5()
309                                    .child(
310                                        Icon::new(IconName::ArrowCircle)
311                                            .size(IconSize::Small)
312                                            .color(Color::Muted)
313                                            .with_animation(
314                                                "arrow-circle",
315                                                Animation::new(Duration::from_secs(2)).repeat(),
316                                                |icon, delta| {
317                                                    icon.transform(Transformation::rotate(
318                                                        percentage(delta),
319                                                    ))
320                                                },
321                                            ),
322                                    )
323                                    .child(
324                                        Label::new("Generating…")
325                                            .size(LabelSize::Small)
326                                            .color(Color::Muted),
327                                    ),
328                            )
329                        },
330                    )
331                    .when_some(context, |parent, context| {
332                        if !context.is_empty() {
333                            parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
334                                context.into_iter().map(|context| {
335                                    ContextPill::new_added(context, false, false, None)
336                                }),
337                            ))
338                        } else {
339                            parent
340                        }
341                    }),
342            )
343            .into_any()
344    }
345}
346
347impl Render for ActiveThread {
348    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
349        let is_streaming_completion = self.thread.read(cx).is_streaming();
350
351        let focus_handle = self.focus_handle.clone();
352
353        v_flex()
354            .size_full()
355            .child(list(self.list_state.clone()).flex_grow())
356            .child(
357                h_flex()
358                    .absolute()
359                    .bottom_1()
360                    .flex_shrink()
361                    .justify_center()
362                    .w_full()
363                    .when(is_streaming_completion, |parent| {
364                        parent.child(
365                            h_flex()
366                                .gap_2()
367                                .p_1p5()
368                                .rounded_md()
369                                .bg(cx.theme().colors().elevated_surface_background)
370                                .child(Label::new("Generating…").size(LabelSize::Small))
371                                .child(
372                                    ButtonLike::new("cancel-generation")
373                                        .style(ButtonStyle::Filled)
374                                        .child(Label::new("Cancel").size(LabelSize::Small))
375                                        .children(
376                                            KeyBinding::for_action_in(
377                                                &editor::actions::Cancel,
378                                                &self.focus_handle,
379                                                cx,
380                                            )
381                                            .map(|binding| binding.into_any_element()),
382                                        )
383                                        .on_click(move |_event, cx| {
384                                            focus_handle
385                                                .dispatch_action(&editor::actions::Cancel, cx);
386                                        }),
387                                ),
388                        )
389                    }),
390            )
391    }
392}