active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
  7    ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
  8    TextStyleRefinement, UnderlineStyle, View, WeakView,
  9};
 10use language::LanguageRegistry;
 11use language_model::Role;
 12use markdown::{Markdown, MarkdownStyle};
 13use settings::Settings as _;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use workspace::Workspace;
 17
 18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 19use crate::thread_store::ThreadStore;
 20use crate::ui::ContextPill;
 21
 22pub struct ActiveThread {
 23    workspace: WeakView<Workspace>,
 24    language_registry: Arc<LanguageRegistry>,
 25    tools: Arc<ToolWorkingSet>,
 26    thread_store: Model<ThreadStore>,
 27    thread: Model<Thread>,
 28    messages: Vec<MessageId>,
 29    list_state: ListState,
 30    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 31    last_error: Option<ThreadError>,
 32    _subscriptions: Vec<Subscription>,
 33}
 34
 35impl ActiveThread {
 36    pub fn new(
 37        thread: Model<Thread>,
 38        thread_store: Model<ThreadStore>,
 39        workspace: WeakView<Workspace>,
 40        language_registry: Arc<LanguageRegistry>,
 41        tools: Arc<ToolWorkingSet>,
 42        cx: &mut ViewContext<Self>,
 43    ) -> Self {
 44        let subscriptions = vec![
 45            cx.observe(&thread, |_, _, cx| cx.notify()),
 46            cx.subscribe(&thread, Self::handle_thread_event),
 47        ];
 48
 49        let mut this = Self {
 50            workspace,
 51            language_registry,
 52            tools,
 53            thread_store,
 54            thread: thread.clone(),
 55            messages: Vec::new(),
 56            rendered_messages_by_id: HashMap::default(),
 57            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 58                let this = cx.view().downgrade();
 59                move |ix, cx: &mut WindowContext| {
 60                    this.update(cx, |this, cx| this.render_message(ix, cx))
 61                        .unwrap()
 62                }
 63            }),
 64            last_error: None,
 65            _subscriptions: subscriptions,
 66        };
 67
 68        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 69            this.push_message(&message.id, message.text.clone(), cx);
 70        }
 71
 72        this
 73    }
 74
 75    pub fn thread(&self) -> &Model<Thread> {
 76        &self.thread
 77    }
 78
 79    pub fn is_empty(&self) -> bool {
 80        self.messages.is_empty()
 81    }
 82
 83    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 84        self.thread.read(cx).summary()
 85    }
 86
 87    pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
 88        self.thread.read(cx).summary_or_default()
 89    }
 90
 91    pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
 92        self.last_error.take();
 93        self.thread
 94            .update(cx, |thread, _cx| thread.cancel_last_completion())
 95    }
 96
 97    pub fn last_error(&self) -> Option<ThreadError> {
 98        self.last_error.clone()
 99    }
100
101    pub fn clear_last_error(&mut self) {
102        self.last_error.take();
103    }
104
105    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
106        let old_len = self.messages.len();
107        self.messages.push(*id);
108        self.list_state.splice(old_len..old_len, 1);
109
110        let theme_settings = ThemeSettings::get_global(cx);
111        let colors = cx.theme().colors();
112        let ui_font_size = TextSize::Default.rems(cx);
113        let buffer_font_size = TextSize::Small.rems(cx);
114        let mut text_style = cx.text_style();
115
116        text_style.refine(&TextStyleRefinement {
117            font_family: Some(theme_settings.ui_font.family.clone()),
118            font_size: Some(ui_font_size.into()),
119            color: Some(cx.theme().colors().text),
120            ..Default::default()
121        });
122
123        let markdown_style = MarkdownStyle {
124            base_text_style: text_style,
125            syntax: cx.theme().syntax().clone(),
126            selection_background_color: cx.theme().players().local().selection,
127            code_block: StyleRefinement {
128                margin: EdgesRefinement {
129                    top: Some(Length::Definite(rems(0.).into())),
130                    left: Some(Length::Definite(rems(0.).into())),
131                    right: Some(Length::Definite(rems(0.).into())),
132                    bottom: Some(Length::Definite(rems(0.5).into())),
133                },
134                padding: EdgesRefinement {
135                    top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
136                    left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
137                    right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
138                    bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
139                },
140                background: Some(colors.editor_background.into()),
141                border_color: Some(colors.border_variant),
142                border_widths: EdgesRefinement {
143                    top: Some(AbsoluteLength::Pixels(Pixels(1.))),
144                    left: Some(AbsoluteLength::Pixels(Pixels(1.))),
145                    right: Some(AbsoluteLength::Pixels(Pixels(1.))),
146                    bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
147                },
148                text: Some(TextStyleRefinement {
149                    font_family: Some(theme_settings.buffer_font.family.clone()),
150                    font_size: Some(buffer_font_size.into()),
151                    ..Default::default()
152                }),
153                ..Default::default()
154            },
155            inline_code: TextStyleRefinement {
156                font_family: Some(theme_settings.buffer_font.family.clone()),
157                font_size: Some(buffer_font_size.into()),
158                background_color: Some(colors.editor_foreground.opacity(0.1)),
159                ..Default::default()
160            },
161            link: TextStyleRefinement {
162                background_color: Some(colors.editor_foreground.opacity(0.025)),
163                underline: Some(UnderlineStyle {
164                    color: Some(colors.text_accent.opacity(0.5)),
165                    thickness: px(1.),
166                    ..Default::default()
167                }),
168                ..Default::default()
169            },
170            ..Default::default()
171        };
172
173        let markdown = cx.new_view(|cx| {
174            Markdown::new(
175                text,
176                markdown_style,
177                Some(self.language_registry.clone()),
178                None,
179                cx,
180            )
181        });
182        self.rendered_messages_by_id.insert(*id, markdown);
183        self.list_state.scroll_to(ListOffset {
184            item_ix: old_len,
185            offset_in_item: Pixels(0.0),
186        });
187    }
188
189    fn handle_thread_event(
190        &mut self,
191        _: Model<Thread>,
192        event: &ThreadEvent,
193        cx: &mut ViewContext<Self>,
194    ) {
195        match event {
196            ThreadEvent::ShowError(error) => {
197                self.last_error = Some(error.clone());
198            }
199            ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
200                self.thread_store
201                    .update(cx, |thread_store, cx| {
202                        thread_store.save_thread(&self.thread, cx)
203                    })
204                    .detach_and_log_err(cx);
205            }
206            ThreadEvent::StreamedAssistantText(message_id, text) => {
207                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
208                    markdown.update(cx, |markdown, cx| {
209                        markdown.append(text, cx);
210                    });
211                }
212            }
213            ThreadEvent::MessageAdded(message_id) => {
214                if let Some(message_text) = self
215                    .thread
216                    .read(cx)
217                    .message(*message_id)
218                    .map(|message| message.text.clone())
219                {
220                    self.push_message(message_id, message_text, cx);
221                }
222
223                self.thread_store
224                    .update(cx, |thread_store, cx| {
225                        thread_store.save_thread(&self.thread, cx)
226                    })
227                    .detach_and_log_err(cx);
228
229                cx.notify();
230            }
231            ThreadEvent::UsePendingTools => {
232                let pending_tool_uses = self
233                    .thread
234                    .read(cx)
235                    .pending_tool_uses()
236                    .into_iter()
237                    .filter(|tool_use| tool_use.status.is_idle())
238                    .cloned()
239                    .collect::<Vec<_>>();
240
241                for tool_use in pending_tool_uses {
242                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
243                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
244
245                        self.thread.update(cx, |thread, cx| {
246                            thread.insert_tool_output(
247                                tool_use.assistant_message_id,
248                                tool_use.id.clone(),
249                                task,
250                                cx,
251                            );
252                        });
253                    }
254                }
255            }
256            ThreadEvent::ToolFinished { .. } => {}
257        }
258    }
259
260    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
261        let message_id = self.messages[ix];
262        let Some(message) = self.thread.read(cx).message(message_id) else {
263            return Empty.into_any();
264        };
265
266        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
267            return Empty.into_any();
268        };
269
270        let context = self.thread.read(cx).context_for_message(message_id);
271        let colors = cx.theme().colors();
272
273        let message_content = v_flex()
274            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
275            .when_some(context, |parent, context| {
276                if !context.is_empty() {
277                    parent.child(
278                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
279                            context
280                                .into_iter()
281                                .map(|context| ContextPill::added(context, false, false, None)),
282                        ),
283                    )
284                } else {
285                    parent
286                }
287            });
288
289        let styled_message = match message.role {
290            Role::User => v_flex()
291                .id(("message-container", ix))
292                .py_1()
293                .px_2p5()
294                .child(
295                    v_flex()
296                        .bg(colors.editor_background)
297                        .rounded_lg()
298                        .border_1()
299                        .border_color(colors.border)
300                        .shadow_sm()
301                        .child(
302                            h_flex()
303                                .py_1()
304                                .px_2()
305                                .bg(colors.editor_foreground.opacity(0.05))
306                                .border_b_1()
307                                .border_color(colors.border)
308                                .justify_between()
309                                .rounded_t(px(6.))
310                                .child(
311                                    h_flex()
312                                        .gap_1p5()
313                                        .child(
314                                            Icon::new(IconName::PersonCircle)
315                                                .size(IconSize::XSmall)
316                                                .color(Color::Muted),
317                                        )
318                                        .child(
319                                            Label::new("You")
320                                                .size(LabelSize::Small)
321                                                .color(Color::Muted),
322                                        ),
323                                ),
324                        )
325                        .child(message_content),
326                ),
327            Role::Assistant => div().id(("message-container", ix)).child(message_content),
328            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
329                v_flex()
330                    .bg(colors.editor_background)
331                    .rounded_md()
332                    .child(message_content),
333            ),
334        };
335
336        styled_message.into_any()
337    }
338}
339
340impl Render for ActiveThread {
341    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
342        v_flex()
343            .size_full()
344            .pt_1p5()
345            .child(list(self.list_state.clone()).flex_grow())
346    }
347}