active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
  7    ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
  8    TextStyleRefinement, UnderlineStyle, View, WeakView,
  9};
 10use language::LanguageRegistry;
 11use language_model::Role;
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::ui::ContextPill;
 20
/// View that displays the messages of a [`Thread`] as a list of markdown
/// cards and reacts to the thread's events (streamed assistant text, newly
/// added messages, errors, and requests to run pending tool uses).
pub struct ActiveThread {
    /// Workspace handed to each tool when it is run for a pending tool use.
    workspace: WeakView<Workspace>,
    /// Registry passed to every message's [`Markdown`] view
    /// (presumably used for code-block language lookup — confirm in `markdown`).
    language_registry: Arc<LanguageRegistry>,
    /// Tools that can be looked up by name when the thread asks to use them.
    tools: Arc<ToolWorkingSet>,
    /// The thread whose messages are displayed.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order; entry `i` backs list item `i`.
    messages: Vec<MessageId>,
    /// State of the `list` element that renders the messages.
    list_state: ListState,
    /// Markdown view holding the (possibly still streaming) text of each message.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported by the thread, until it is cleared.
    last_error: Option<ThreadError>,
    /// Stored (never read) so the thread observation and event subscription
    /// created in `new` live as long as this view.
    _subscriptions: Vec<Subscription>,
}
 32
 33impl ActiveThread {
 34    pub fn new(
 35        thread: Model<Thread>,
 36        workspace: WeakView<Workspace>,
 37        language_registry: Arc<LanguageRegistry>,
 38        tools: Arc<ToolWorkingSet>,
 39        cx: &mut ViewContext<Self>,
 40    ) -> Self {
 41        let subscriptions = vec![
 42            cx.observe(&thread, |_, _, cx| cx.notify()),
 43            cx.subscribe(&thread, Self::handle_thread_event),
 44        ];
 45
 46        let mut this = Self {
 47            workspace,
 48            language_registry,
 49            tools,
 50            thread: thread.clone(),
 51            messages: Vec::new(),
 52            rendered_messages_by_id: HashMap::default(),
 53            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 54                let this = cx.view().downgrade();
 55                move |ix, cx: &mut WindowContext| {
 56                    this.update(cx, |this, cx| this.render_message(ix, cx))
 57                        .unwrap()
 58                }
 59            }),
 60            last_error: None,
 61            _subscriptions: subscriptions,
 62        };
 63
 64        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 65            this.push_message(&message.id, message.text.clone(), cx);
 66        }
 67
 68        this
 69    }
 70
 71    pub fn is_empty(&self) -> bool {
 72        self.messages.is_empty()
 73    }
 74
 75    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 76        self.thread.read(cx).summary()
 77    }
 78
 79    pub fn last_error(&self) -> Option<ThreadError> {
 80        self.last_error.clone()
 81    }
 82
 83    pub fn clear_last_error(&mut self) {
 84        self.last_error.take();
 85    }
 86
 87    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
 88        let old_len = self.messages.len();
 89        self.messages.push(*id);
 90        self.list_state.splice(old_len..old_len, 1);
 91
 92        let theme_settings = ThemeSettings::get_global(cx);
 93        let colors = cx.theme().colors();
 94        let ui_font_size = TextSize::Default.rems(cx);
 95        let buffer_font_size = TextSize::Small.rems(cx);
 96        let mut text_style = cx.text_style();
 97
 98        text_style.refine(&TextStyleRefinement {
 99            font_family: Some(theme_settings.ui_font.family.clone()),
100            font_size: Some(ui_font_size.into()),
101            color: Some(cx.theme().colors().text),
102            ..Default::default()
103        });
104
105        let markdown_style = MarkdownStyle {
106            base_text_style: text_style,
107            syntax: cx.theme().syntax().clone(),
108            selection_background_color: cx.theme().players().local().selection,
109            code_block: StyleRefinement {
110                margin: EdgesRefinement {
111                    top: Some(Length::Definite(rems(1.0).into())),
112                    left: Some(Length::Definite(rems(0.).into())),
113                    right: Some(Length::Definite(rems(0.).into())),
114                    bottom: Some(Length::Definite(rems(1.).into())),
115                },
116                padding: EdgesRefinement {
117                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
118                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
119                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
120                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
121                },
122                background: Some(colors.editor_foreground.opacity(0.01).into()),
123                border_color: Some(colors.border_variant.opacity(0.3)),
124                border_widths: EdgesRefinement {
125                    top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
126                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
127                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
128                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
129                },
130                text: Some(TextStyleRefinement {
131                    font_family: Some(theme_settings.buffer_font.family.clone()),
132                    font_size: Some(buffer_font_size.into()),
133                    ..Default::default()
134                }),
135                ..Default::default()
136            },
137            inline_code: TextStyleRefinement {
138                font_family: Some(theme_settings.buffer_font.family.clone()),
139                font_size: Some(buffer_font_size.into()),
140                background_color: Some(colors.editor_foreground.opacity(0.01)),
141                ..Default::default()
142            },
143            link: TextStyleRefinement {
144                background_color: Some(colors.editor_foreground.opacity(0.025)),
145                underline: Some(UnderlineStyle {
146                    color: Some(colors.text_accent.opacity(0.5)),
147                    thickness: px(1.),
148                    ..Default::default()
149                }),
150                ..Default::default()
151            },
152            ..Default::default()
153        };
154
155        let markdown = cx.new_view(|cx| {
156            Markdown::new(
157                text,
158                markdown_style,
159                Some(self.language_registry.clone()),
160                None,
161                cx,
162            )
163        });
164        self.rendered_messages_by_id.insert(*id, markdown);
165        self.list_state.scroll_to(ListOffset {
166            item_ix: old_len,
167            offset_in_item: Pixels(0.0),
168        });
169    }
170
171    fn handle_thread_event(
172        &mut self,
173        _: Model<Thread>,
174        event: &ThreadEvent,
175        cx: &mut ViewContext<Self>,
176    ) {
177        match event {
178            ThreadEvent::ShowError(error) => {
179                self.last_error = Some(error.clone());
180            }
181            ThreadEvent::StreamedCompletion => {}
182            ThreadEvent::SummaryChanged => {}
183            ThreadEvent::StreamedAssistantText(message_id, text) => {
184                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
185                    markdown.update(cx, |markdown, cx| {
186                        markdown.append(text, cx);
187                    });
188                }
189            }
190            ThreadEvent::MessageAdded(message_id) => {
191                if let Some(message_text) = self
192                    .thread
193                    .read(cx)
194                    .message(*message_id)
195                    .map(|message| message.text.clone())
196                {
197                    self.push_message(message_id, message_text, cx);
198                }
199
200                cx.notify();
201            }
202            ThreadEvent::UsePendingTools => {
203                let pending_tool_uses = self
204                    .thread
205                    .read(cx)
206                    .pending_tool_uses()
207                    .into_iter()
208                    .filter(|tool_use| tool_use.status.is_idle())
209                    .cloned()
210                    .collect::<Vec<_>>();
211
212                for tool_use in pending_tool_uses {
213                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
214                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
215
216                        self.thread.update(cx, |thread, cx| {
217                            thread.insert_tool_output(
218                                tool_use.assistant_message_id,
219                                tool_use.id.clone(),
220                                task,
221                                cx,
222                            );
223                        });
224                    }
225                }
226            }
227            ThreadEvent::ToolFinished { .. } => {}
228        }
229    }
230
231    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
232        let message_id = self.messages[ix];
233        let Some(message) = self.thread.read(cx).message(message_id) else {
234            return Empty.into_any();
235        };
236
237        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
238            return Empty.into_any();
239        };
240
241        let context = self.thread.read(cx).context_for_message(message_id);
242        let colors = cx.theme().colors();
243
244        let (role_icon, role_name, role_color) = match message.role {
245            Role::User => (IconName::Person, "You", Color::Muted),
246            Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
247            Role::System => (IconName::Settings, "System", Color::Default),
248        };
249
250        div()
251            .id(("message-container", ix))
252            .py_1()
253            .px_2()
254            .child(
255                v_flex()
256                    .border_1()
257                    .border_color(colors.border_variant)
258                    .bg(colors.editor_background)
259                    .rounded_md()
260                    .child(
261                        h_flex()
262                            .py_1p5()
263                            .px_2p5()
264                            .border_b_1()
265                            .border_color(colors.border_variant)
266                            .justify_between()
267                            .child(
268                                h_flex()
269                                    .gap_1p5()
270                                    .child(
271                                        Icon::new(role_icon)
272                                            .size(IconSize::XSmall)
273                                            .color(role_color),
274                                    )
275                                    .child(
276                                        Label::new(role_name)
277                                            .size(LabelSize::XSmall)
278                                            .color(role_color),
279                                    ),
280                            ),
281                    )
282                    .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
283                    .when_some(context, |parent, context| {
284                        if !context.is_empty() {
285                            parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
286                                context.iter().map(|context| {
287                                    ContextPill::new_added(context.clone(), false, None)
288                                }),
289                            ))
290                        } else {
291                            parent
292                        }
293                    }),
294            )
295            .into_any()
296    }
297}
298
299impl Render for ActiveThread {
300    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
301        list(self.list_state.clone()).flex_1().py_1()
302    }
303}