// active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
  7    ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
  8    TextStyleRefinement, UnderlineStyle, View, WeakView,
  9};
 10use language::LanguageRegistry;
 11use language_model::Role;
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::ui::ContextPill;
 20
 21pub struct ActiveThread {
 22    workspace: WeakView<Workspace>,
 23    language_registry: Arc<LanguageRegistry>,
 24    tools: Arc<ToolWorkingSet>,
 25    thread: Model<Thread>,
 26    messages: Vec<MessageId>,
 27    list_state: ListState,
 28    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 29    last_error: Option<ThreadError>,
 30    _subscriptions: Vec<Subscription>,
 31}
 32
 33impl ActiveThread {
 34    pub fn new(
 35        thread: Model<Thread>,
 36        workspace: WeakView<Workspace>,
 37        language_registry: Arc<LanguageRegistry>,
 38        tools: Arc<ToolWorkingSet>,
 39        cx: &mut ViewContext<Self>,
 40    ) -> Self {
 41        let subscriptions = vec![
 42            cx.observe(&thread, |_, _, cx| cx.notify()),
 43            cx.subscribe(&thread, Self::handle_thread_event),
 44        ];
 45
 46        let mut this = Self {
 47            workspace,
 48            language_registry,
 49            tools,
 50            thread: thread.clone(),
 51            messages: Vec::new(),
 52            rendered_messages_by_id: HashMap::default(),
 53            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 54                let this = cx.view().downgrade();
 55                move |ix, cx: &mut WindowContext| {
 56                    this.update(cx, |this, cx| this.render_message(ix, cx))
 57                        .unwrap()
 58                }
 59            }),
 60            last_error: None,
 61            _subscriptions: subscriptions,
 62        };
 63
 64        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 65            this.push_message(&message.id, message.text.clone(), cx);
 66        }
 67
 68        this
 69    }
 70
 71    pub fn thread(&self) -> &Model<Thread> {
 72        &self.thread
 73    }
 74
 75    pub fn is_empty(&self) -> bool {
 76        self.messages.is_empty()
 77    }
 78
 79    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 80        self.thread.read(cx).summary()
 81    }
 82
 83    pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
 84        self.thread.read(cx).summary_or_default()
 85    }
 86
 87    pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
 88        self.last_error.take();
 89        self.thread
 90            .update(cx, |thread, _cx| thread.cancel_last_completion())
 91    }
 92
 93    pub fn last_error(&self) -> Option<ThreadError> {
 94        self.last_error.clone()
 95    }
 96
 97    pub fn clear_last_error(&mut self) {
 98        self.last_error.take();
 99    }
100
101    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
102        let old_len = self.messages.len();
103        self.messages.push(*id);
104        self.list_state.splice(old_len..old_len, 1);
105
106        let theme_settings = ThemeSettings::get_global(cx);
107        let colors = cx.theme().colors();
108        let ui_font_size = TextSize::Default.rems(cx);
109        let buffer_font_size = TextSize::Small.rems(cx);
110        let mut text_style = cx.text_style();
111
112        text_style.refine(&TextStyleRefinement {
113            font_family: Some(theme_settings.ui_font.family.clone()),
114            font_size: Some(ui_font_size.into()),
115            color: Some(cx.theme().colors().text),
116            ..Default::default()
117        });
118
119        let markdown_style = MarkdownStyle {
120            base_text_style: text_style,
121            syntax: cx.theme().syntax().clone(),
122            selection_background_color: cx.theme().players().local().selection,
123            code_block: StyleRefinement {
124                margin: EdgesRefinement {
125                    top: Some(Length::Definite(rems(0.).into())),
126                    left: Some(Length::Definite(rems(0.).into())),
127                    right: Some(Length::Definite(rems(0.).into())),
128                    bottom: Some(Length::Definite(rems(0.5).into())),
129                },
130                padding: EdgesRefinement {
131                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
132                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
134                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
135                },
136                background: Some(colors.editor_background.into()),
137                border_color: Some(colors.border_variant),
138                border_widths: EdgesRefinement {
139                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
140                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
141                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
142                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
143                },
144                text: Some(TextStyleRefinement {
145                    font_family: Some(theme_settings.buffer_font.family.clone()),
146                    font_size: Some(buffer_font_size.into()),
147                    ..Default::default()
148                }),
149                ..Default::default()
150            },
151            inline_code: TextStyleRefinement {
152                font_family: Some(theme_settings.buffer_font.family.clone()),
153                font_size: Some(buffer_font_size.into()),
154                background_color: Some(colors.editor_foreground.opacity(0.1)),
155                ..Default::default()
156            },
157            link: TextStyleRefinement {
158                background_color: Some(colors.editor_foreground.opacity(0.025)),
159                underline: Some(UnderlineStyle {
160                    color: Some(colors.text_accent.opacity(0.5)),
161                    thickness: px(1.),
162                    ..Default::default()
163                }),
164                ..Default::default()
165            },
166            ..Default::default()
167        };
168
169        let markdown = cx.new_view(|cx| {
170            Markdown::new(
171                text,
172                markdown_style,
173                Some(self.language_registry.clone()),
174                None,
175                cx,
176            )
177        });
178        self.rendered_messages_by_id.insert(*id, markdown);
179        self.list_state.scroll_to(ListOffset {
180            item_ix: old_len,
181            offset_in_item: Pixels(0.0),
182        });
183    }
184
185    fn handle_thread_event(
186        &mut self,
187        _: Model<Thread>,
188        event: &ThreadEvent,
189        cx: &mut ViewContext<Self>,
190    ) {
191        match event {
192            ThreadEvent::ShowError(error) => {
193                self.last_error = Some(error.clone());
194            }
195            ThreadEvent::StreamedCompletion => {}
196            ThreadEvent::SummaryChanged => {}
197            ThreadEvent::StreamedAssistantText(message_id, text) => {
198                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
199                    markdown.update(cx, |markdown, cx| {
200                        markdown.append(text, cx);
201                    });
202                }
203            }
204            ThreadEvent::MessageAdded(message_id) => {
205                if let Some(message_text) = self
206                    .thread
207                    .read(cx)
208                    .message(*message_id)
209                    .map(|message| message.text.clone())
210                {
211                    self.push_message(message_id, message_text, cx);
212                }
213
214                cx.notify();
215            }
216            ThreadEvent::UsePendingTools => {
217                let pending_tool_uses = self
218                    .thread
219                    .read(cx)
220                    .pending_tool_uses()
221                    .into_iter()
222                    .filter(|tool_use| tool_use.status.is_idle())
223                    .cloned()
224                    .collect::<Vec<_>>();
225
226                for tool_use in pending_tool_uses {
227                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
228                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
229
230                        self.thread.update(cx, |thread, cx| {
231                            thread.insert_tool_output(
232                                tool_use.assistant_message_id,
233                                tool_use.id.clone(),
234                                task,
235                                cx,
236                            );
237                        });
238                    }
239                }
240            }
241            ThreadEvent::ToolFinished { .. } => {}
242        }
243    }
244
245    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
246        let message_id = self.messages[ix];
247        let Some(message) = self.thread.read(cx).message(message_id) else {
248            return Empty.into_any();
249        };
250
251        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
252            return Empty.into_any();
253        };
254
255        let context = self.thread.read(cx).context_for_message(message_id);
256        let colors = cx.theme().colors();
257
258        let message_content = v_flex()
259            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
260            .when_some(context, |parent, context| {
261                if !context.is_empty() {
262                    parent.child(
263                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
264                            context
265                                .into_iter()
266                                .map(|context| ContextPill::added(context, false, false, None)),
267                        ),
268                    )
269                } else {
270                    parent
271                }
272            });
273
274        let styled_message = match message.role {
275            Role::User => v_flex()
276                .id(("message-container", ix))
277                .py_1()
278                .px_2p5()
279                .child(
280                    v_flex()
281                        .bg(colors.editor_background)
282                        .rounded_lg()
283                        .border_1()
284                        .border_color(colors.border)
285                        .shadow_sm()
286                        .child(
287                            h_flex()
288                                .py_1()
289                                .px_2()
290                                .bg(colors.editor_foreground.opacity(0.05))
291                                .border_b_1()
292                                .border_color(colors.border)
293                                .justify_between()
294                                .rounded_t(px(6.))
295                                .child(
296                                    h_flex()
297                                        .gap_1p5()
298                                        .child(
299                                            Icon::new(IconName::PersonCircle)
300                                                .size(IconSize::XSmall)
301                                                .color(Color::Muted),
302                                        )
303                                        .child(
304                                            Label::new("You")
305                                                .size(LabelSize::Small)
306                                                .color(Color::Muted),
307                                        ),
308                                ),
309                        )
310                        .child(message_content),
311                ),
312            Role::Assistant => div().id(("message-container", ix)).child(message_content),
313            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
314                v_flex()
315                    .bg(colors.editor_background)
316                    .rounded_md()
317                    .child(message_content),
318            ),
319        };
320
321        styled_message.into_any()
322    }
323}
324
325impl Render for ActiveThread {
326    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
327        v_flex()
328            .size_full()
329            .pt_1p5()
330            .child(list(self.list_state.clone()).flex_grow())
331    }
332}