active_thread.rs

  1use std::sync::Arc;
  2use std::time::Duration;
  3
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use gpui::{
  7    list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, AppContext,
  8    DefiniteLength, EdgesRefinement, Empty, Length, ListAlignment, ListOffset, ListState, Model,
  9    StyleRefinement, Subscription, TextStyleRefinement, Transformation, UnderlineStyle, View,
 10    WeakView,
 11};
 12use language::LanguageRegistry;
 13use language_model::Role;
 14use markdown::{Markdown, MarkdownStyle};
 15use settings::Settings as _;
 16use theme::ThemeSettings;
 17use ui::prelude::*;
 18use workspace::Workspace;
 19
 20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 21use crate::ui::ContextPill;
 22
 23pub struct ActiveThread {
 24    workspace: WeakView<Workspace>,
 25    language_registry: Arc<LanguageRegistry>,
 26    tools: Arc<ToolWorkingSet>,
 27    pub(crate) thread: Model<Thread>,
 28    messages: Vec<MessageId>,
 29    list_state: ListState,
 30    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 31    last_error: Option<ThreadError>,
 32    _subscriptions: Vec<Subscription>,
 33}
 34
 35impl ActiveThread {
 36    pub fn new(
 37        thread: Model<Thread>,
 38        workspace: WeakView<Workspace>,
 39        language_registry: Arc<LanguageRegistry>,
 40        tools: Arc<ToolWorkingSet>,
 41        cx: &mut ViewContext<Self>,
 42    ) -> Self {
 43        let subscriptions = vec![
 44            cx.observe(&thread, |_, _, cx| cx.notify()),
 45            cx.subscribe(&thread, Self::handle_thread_event),
 46        ];
 47
 48        let mut this = Self {
 49            workspace,
 50            language_registry,
 51            tools,
 52            thread: thread.clone(),
 53            messages: Vec::new(),
 54            rendered_messages_by_id: HashMap::default(),
 55            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 56                let this = cx.view().downgrade();
 57                move |ix, cx: &mut WindowContext| {
 58                    this.update(cx, |this, cx| this.render_message(ix, cx))
 59                        .unwrap()
 60                }
 61            }),
 62            last_error: None,
 63            _subscriptions: subscriptions,
 64        };
 65
 66        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 67            this.push_message(&message.id, message.text.clone(), cx);
 68        }
 69
 70        this
 71    }
 72
 73    pub fn is_empty(&self) -> bool {
 74        self.messages.is_empty()
 75    }
 76
 77    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 78        self.thread.read(cx).summary()
 79    }
 80
 81    pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
 82        self.thread.read(cx).summary_or_default()
 83    }
 84
 85    pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
 86        self.last_error.take();
 87        self.thread
 88            .update(cx, |thread, _cx| thread.cancel_last_completion())
 89    }
 90
 91    pub fn last_error(&self) -> Option<ThreadError> {
 92        self.last_error.clone()
 93    }
 94
 95    pub fn clear_last_error(&mut self) {
 96        self.last_error.take();
 97    }
 98
 99    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
100        let old_len = self.messages.len();
101        self.messages.push(*id);
102        self.list_state.splice(old_len..old_len, 1);
103
104        let theme_settings = ThemeSettings::get_global(cx);
105        let colors = cx.theme().colors();
106        let ui_font_size = TextSize::Default.rems(cx);
107        let buffer_font_size = TextSize::Small.rems(cx);
108        let mut text_style = cx.text_style();
109
110        text_style.refine(&TextStyleRefinement {
111            font_family: Some(theme_settings.ui_font.family.clone()),
112            font_size: Some(ui_font_size.into()),
113            color: Some(cx.theme().colors().text),
114            ..Default::default()
115        });
116
117        let markdown_style = MarkdownStyle {
118            base_text_style: text_style,
119            syntax: cx.theme().syntax().clone(),
120            selection_background_color: cx.theme().players().local().selection,
121            code_block: StyleRefinement {
122                margin: EdgesRefinement {
123                    top: Some(Length::Definite(rems(1.0).into())),
124                    left: Some(Length::Definite(rems(0.).into())),
125                    right: Some(Length::Definite(rems(0.).into())),
126                    bottom: Some(Length::Definite(rems(1.).into())),
127                },
128                padding: EdgesRefinement {
129                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
130                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
131                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
132                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133                },
134                background: Some(colors.editor_foreground.opacity(0.01).into()),
135                border_color: Some(colors.border_variant.opacity(0.3)),
136                border_widths: EdgesRefinement {
137                    top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
138                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
139                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
140                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
141                },
142                text: Some(TextStyleRefinement {
143                    font_family: Some(theme_settings.buffer_font.family.clone()),
144                    font_size: Some(buffer_font_size.into()),
145                    ..Default::default()
146                }),
147                ..Default::default()
148            },
149            inline_code: TextStyleRefinement {
150                font_family: Some(theme_settings.buffer_font.family.clone()),
151                font_size: Some(buffer_font_size.into()),
152                background_color: Some(colors.editor_foreground.opacity(0.1)),
153                ..Default::default()
154            },
155            link: TextStyleRefinement {
156                background_color: Some(colors.editor_foreground.opacity(0.025)),
157                underline: Some(UnderlineStyle {
158                    color: Some(colors.text_accent.opacity(0.5)),
159                    thickness: px(1.),
160                    ..Default::default()
161                }),
162                ..Default::default()
163            },
164            ..Default::default()
165        };
166
167        let markdown = cx.new_view(|cx| {
168            Markdown::new(
169                text,
170                markdown_style,
171                Some(self.language_registry.clone()),
172                None,
173                cx,
174            )
175        });
176        self.rendered_messages_by_id.insert(*id, markdown);
177        self.list_state.scroll_to(ListOffset {
178            item_ix: old_len,
179            offset_in_item: Pixels(0.0),
180        });
181    }
182
183    fn handle_thread_event(
184        &mut self,
185        _: Model<Thread>,
186        event: &ThreadEvent,
187        cx: &mut ViewContext<Self>,
188    ) {
189        match event {
190            ThreadEvent::ShowError(error) => {
191                self.last_error = Some(error.clone());
192            }
193            ThreadEvent::StreamedCompletion => {}
194            ThreadEvent::SummaryChanged => {}
195            ThreadEvent::StreamedAssistantText(message_id, text) => {
196                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
197                    markdown.update(cx, |markdown, cx| {
198                        markdown.append(text, cx);
199                    });
200                }
201            }
202            ThreadEvent::MessageAdded(message_id) => {
203                if let Some(message_text) = self
204                    .thread
205                    .read(cx)
206                    .message(*message_id)
207                    .map(|message| message.text.clone())
208                {
209                    self.push_message(message_id, message_text, cx);
210                }
211
212                cx.notify();
213            }
214            ThreadEvent::UsePendingTools => {
215                let pending_tool_uses = self
216                    .thread
217                    .read(cx)
218                    .pending_tool_uses()
219                    .into_iter()
220                    .filter(|tool_use| tool_use.status.is_idle())
221                    .cloned()
222                    .collect::<Vec<_>>();
223
224                for tool_use in pending_tool_uses {
225                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
226                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
227
228                        self.thread.update(cx, |thread, cx| {
229                            thread.insert_tool_output(
230                                tool_use.assistant_message_id,
231                                tool_use.id.clone(),
232                                task,
233                                cx,
234                            );
235                        });
236                    }
237                }
238            }
239            ThreadEvent::ToolFinished { .. } => {}
240        }
241    }
242
243    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
244        let message_id = self.messages[ix];
245        let is_last_message = ix == self.messages.len() - 1;
246        let Some(message) = self.thread.read(cx).message(message_id) else {
247            return Empty.into_any();
248        };
249
250        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
251            return Empty.into_any();
252        };
253
254        let is_streaming_completion = self.thread.read(cx).is_streaming();
255        let context = self.thread.read(cx).context_for_message(message_id);
256        let colors = cx.theme().colors();
257
258        let (role_icon, role_name, role_color) = match message.role {
259            Role::User => (IconName::Person, "You", Color::Muted),
260            Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
261            Role::System => (IconName::Settings, "System", Color::Default),
262        };
263
264        div()
265            .id(("message-container", ix))
266            .py_1()
267            .px_2()
268            .child(
269                v_flex()
270                    .border_1()
271                    .border_color(colors.border_variant)
272                    .bg(colors.editor_background)
273                    .rounded_md()
274                    .child(
275                        h_flex()
276                            .py_1p5()
277                            .px_2p5()
278                            .border_b_1()
279                            .border_color(colors.border_variant)
280                            .justify_between()
281                            .child(
282                                h_flex()
283                                    .gap_1p5()
284                                    .child(
285                                        Icon::new(role_icon)
286                                            .size(IconSize::XSmall)
287                                            .color(role_color),
288                                    )
289                                    .child(
290                                        Label::new(role_name)
291                                            .size(LabelSize::XSmall)
292                                            .color(role_color),
293                                    ),
294                            ),
295                    )
296                    .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
297                    .when(
298                        message.role == Role::Assistant
299                            && is_last_message
300                            && is_streaming_completion,
301                        |parent| {
302                            parent.child(
303                                h_flex()
304                                    .gap_1()
305                                    .p_2p5()
306                                    .child(
307                                        Icon::new(IconName::ArrowCircle)
308                                            .size(IconSize::Small)
309                                            .color(Color::Muted)
310                                            .with_animation(
311                                                "arrow-circle",
312                                                Animation::new(Duration::from_secs(2)).repeat(),
313                                                |icon, delta| {
314                                                    icon.transform(Transformation::rotate(
315                                                        percentage(delta),
316                                                    ))
317                                                },
318                                            ),
319                                    )
320                                    .child(
321                                        Label::new("Generating…")
322                                            .size(LabelSize::Small)
323                                            .color(Color::Muted),
324                                    ),
325                            )
326                        },
327                    )
328                    .when_some(context, |parent, context| {
329                        if !context.is_empty() {
330                            parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
331                                context.into_iter().map(|context| {
332                                    ContextPill::new_added(context, false, false, None)
333                                }),
334                            ))
335                        } else {
336                            parent
337                        }
338                    }),
339            )
340            .into_any()
341    }
342}
343
344impl Render for ActiveThread {
345    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
346        let is_streaming_completion = self.thread.read(cx).is_streaming();
347
348        v_flex()
349            .size_full()
350            .child(list(self.list_state.clone()).flex_grow())
351            .child(
352                h_flex()
353                    .absolute()
354                    .bottom_1()
355                    .flex_shrink()
356                    .justify_center()
357                    .w_full()
358                    .when(is_streaming_completion, |parent| {
359                        parent.child(
360                            h_flex()
361                                .gap_2()
362                                .p_1p5()
363                                .rounded_md()
364                                .bg(cx.theme().colors().elevated_surface_background)
365                                .child(Label::new("Generating…").size(LabelSize::Small))
366                                .child(Label::new("esc to cancel").size(LabelSize::Small)),
367                        )
368                    }),
369            )
370    }
371}