active_thread.rs

  1use std::sync::Arc;
  2
  3use assistant_tool::ToolWorkingSet;
  4use collections::HashMap;
  5use gpui::{
  6    list, AnyElement, AppContext, Empty, ListAlignment, ListState, Model, StyleRefinement,
  7    Subscription, TextStyleRefinement, View, WeakView,
  8};
  9use language::LanguageRegistry;
 10use language_model::Role;
 11use markdown::{Markdown, MarkdownStyle};
 12use settings::Settings as _;
 13use theme::ThemeSettings;
 14use ui::prelude::*;
 15use workspace::Workspace;
 16
 17use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 18use crate::ui::ContextPill;
 19
 20pub struct ActiveThread {
 21    workspace: WeakView<Workspace>,
 22    language_registry: Arc<LanguageRegistry>,
 23    tools: Arc<ToolWorkingSet>,
 24    thread: Model<Thread>,
 25    messages: Vec<MessageId>,
 26    list_state: ListState,
 27    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
 28    last_error: Option<ThreadError>,
 29    _subscriptions: Vec<Subscription>,
 30}
 31
 32impl ActiveThread {
 33    pub fn new(
 34        thread: Model<Thread>,
 35        workspace: WeakView<Workspace>,
 36        language_registry: Arc<LanguageRegistry>,
 37        tools: Arc<ToolWorkingSet>,
 38        cx: &mut ViewContext<Self>,
 39    ) -> Self {
 40        let subscriptions = vec![
 41            cx.observe(&thread, |_, _, cx| cx.notify()),
 42            cx.subscribe(&thread, Self::handle_thread_event),
 43        ];
 44
 45        let mut this = Self {
 46            workspace,
 47            language_registry,
 48            tools,
 49            thread: thread.clone(),
 50            messages: Vec::new(),
 51            rendered_messages_by_id: HashMap::default(),
 52            list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
 53                let this = cx.view().downgrade();
 54                move |ix, cx: &mut WindowContext| {
 55                    this.update(cx, |this, cx| this.render_message(ix, cx))
 56                        .unwrap()
 57                }
 58            }),
 59            last_error: None,
 60            _subscriptions: subscriptions,
 61        };
 62
 63        for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
 64            this.push_message(&message.id, message.text.clone(), cx);
 65        }
 66
 67        this
 68    }
 69
 70    pub fn is_empty(&self) -> bool {
 71        self.messages.is_empty()
 72    }
 73
 74    pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
 75        self.thread.read(cx).summary()
 76    }
 77
 78    pub fn last_error(&self) -> Option<ThreadError> {
 79        self.last_error.clone()
 80    }
 81
 82    pub fn clear_last_error(&mut self) {
 83        self.last_error.take();
 84    }
 85
 86    fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
 87        let old_len = self.messages.len();
 88        self.messages.push(*id);
 89        self.list_state.splice(old_len..old_len, 1);
 90
 91        let theme_settings = ThemeSettings::get_global(cx);
 92        let ui_font_size = TextSize::Default.rems(cx);
 93        let buffer_font_size = theme_settings.buffer_font_size;
 94
 95        let mut text_style = cx.text_style();
 96        text_style.refine(&TextStyleRefinement {
 97            font_family: Some(theme_settings.ui_font.family.clone()),
 98            font_size: Some(ui_font_size.into()),
 99            color: Some(cx.theme().colors().text),
100            ..Default::default()
101        });
102
103        let markdown_style = MarkdownStyle {
104            base_text_style: text_style,
105            syntax: cx.theme().syntax().clone(),
106            selection_background_color: cx.theme().players().local().selection,
107            code_block: StyleRefinement {
108                text: Some(TextStyleRefinement {
109                    font_family: Some(theme_settings.buffer_font.family.clone()),
110                    font_size: Some(buffer_font_size.into()),
111                    ..Default::default()
112                }),
113                ..Default::default()
114            },
115            inline_code: TextStyleRefinement {
116                font_family: Some(theme_settings.buffer_font.family.clone()),
117                font_size: Some(ui_font_size.into()),
118                background_color: Some(cx.theme().colors().editor_background),
119                ..Default::default()
120            },
121            ..Default::default()
122        };
123
124        let markdown = cx.new_view(|cx| {
125            Markdown::new(
126                text,
127                markdown_style,
128                Some(self.language_registry.clone()),
129                None,
130                cx,
131            )
132        });
133        self.rendered_messages_by_id.insert(*id, markdown);
134    }
135
136    fn handle_thread_event(
137        &mut self,
138        _: Model<Thread>,
139        event: &ThreadEvent,
140        cx: &mut ViewContext<Self>,
141    ) {
142        match event {
143            ThreadEvent::ShowError(error) => {
144                self.last_error = Some(error.clone());
145            }
146            ThreadEvent::StreamedCompletion => {}
147            ThreadEvent::SummaryChanged => {}
148            ThreadEvent::StreamedAssistantText(message_id, text) => {
149                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
150                    markdown.update(cx, |markdown, cx| {
151                        markdown.append(text, cx);
152                    });
153                }
154            }
155            ThreadEvent::MessageAdded(message_id) => {
156                if let Some(message_text) = self
157                    .thread
158                    .read(cx)
159                    .message(*message_id)
160                    .map(|message| message.text.clone())
161                {
162                    self.push_message(message_id, message_text, cx);
163                }
164
165                cx.notify();
166            }
167            ThreadEvent::UsePendingTools => {
168                let pending_tool_uses = self
169                    .thread
170                    .read(cx)
171                    .pending_tool_uses()
172                    .into_iter()
173                    .filter(|tool_use| tool_use.status.is_idle())
174                    .cloned()
175                    .collect::<Vec<_>>();
176
177                for tool_use in pending_tool_uses {
178                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
179                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
180
181                        self.thread.update(cx, |thread, cx| {
182                            thread.insert_tool_output(
183                                tool_use.assistant_message_id,
184                                tool_use.id.clone(),
185                                task,
186                                cx,
187                            );
188                        });
189                    }
190                }
191            }
192            ThreadEvent::ToolFinished { .. } => {}
193        }
194    }
195
196    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
197        let message_id = self.messages[ix];
198        let Some(message) = self.thread.read(cx).message(message_id) else {
199            return Empty.into_any();
200        };
201
202        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
203            return Empty.into_any();
204        };
205
206        let context = self.thread.read(cx).context_for_message(message_id);
207
208        let (role_icon, role_name) = match message.role {
209            Role::User => (IconName::Person, "You"),
210            Role::Assistant => (IconName::ZedAssistant, "Assistant"),
211            Role::System => (IconName::Settings, "System"),
212        };
213
214        div()
215            .id(("message-container", ix))
216            .py_1()
217            .px_2()
218            .child(
219                v_flex()
220                    .border_1()
221                    .border_color(cx.theme().colors().border)
222                    .bg(cx.theme().colors().editor_background)
223                    .rounded_md()
224                    .child(
225                        h_flex()
226                            .justify_between()
227                            .py_1()
228                            .px_2()
229                            .border_b_1()
230                            .border_color(cx.theme().colors().border_variant)
231                            .child(
232                                h_flex()
233                                    .gap_1p5()
234                                    .child(
235                                        Icon::new(role_icon)
236                                            .size(IconSize::XSmall)
237                                            .color(Color::Muted),
238                                    )
239                                    .child(Label::new(role_name).size(LabelSize::XSmall)),
240                            ),
241                    )
242                    .child(v_flex().px_2().py_1().text_ui(cx).child(markdown.clone()))
243                    .when_some(context, |parent, context| {
244                        parent.child(
245                            h_flex().flex_wrap().gap_2().p_1p5().children(
246                                context
247                                    .iter()
248                                    .map(|context| ContextPill::new(context.clone())),
249                            ),
250                        )
251                    }),
252            )
253            .into_any()
254    }
255}
256
257impl Render for ActiveThread {
258    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
259        list(self.list_state.clone()).flex_1().py_1()
260    }
261}