1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
7 ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
8 UnderlineStyle, WeakEntity,
9};
10use language::LanguageRegistry;
11use language_model::{LanguageModelToolUseId, Role};
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::{prelude::*, Disclosure};
16use workspace::Workspace;
17
18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus};
19use crate::thread_store::ThreadStore;
20use crate::ui::ContextPill;
21
/// View over a single [`Thread`]: shows its messages as a list and reacts to the
/// thread's events (streamed text, added messages, pending tool uses, errors).
pub struct ActiveThread {
    /// Workspace handed to tools when they are run.
    workspace: WeakEntity<Workspace>,
    /// Registry passed to every message's `Markdown` entity
    /// (presumably used for code-block language lookup — confirm).
    language_registry: Arc<LanguageRegistry>,
    /// Tools available to execute pending tool uses, looked up by name.
    tools: Arc<ToolWorkingSet>,
    /// Store used to save `thread` after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Message ids in display order; entry `i` backs list item `i` in `list_state`.
    messages: Vec<MessageId>,
    /// State of the list that renders the messages.
    list_state: ListState,
    /// Rendered markdown body of each message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// Expanded/collapsed state of each tool use; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported by the thread, until it is taken or cleared.
    last_error: Option<ThreadError>,
    /// Observation and event subscriptions on `thread`. They are never read and are
    /// held only to keep them alive for the view's lifetime (presumably dropping a
    /// `Subscription` unsubscribes — confirm).
    _subscriptions: Vec<Subscription>,
}
35
36impl ActiveThread {
37 pub fn new(
38 thread: Entity<Thread>,
39 thread_store: Entity<ThreadStore>,
40 workspace: WeakEntity<Workspace>,
41 language_registry: Arc<LanguageRegistry>,
42 tools: Arc<ToolWorkingSet>,
43 window: &mut Window,
44 cx: &mut Context<Self>,
45 ) -> Self {
46 let subscriptions = vec![
47 cx.observe(&thread, |_, _, cx| cx.notify()),
48 cx.subscribe_in(&thread, window, Self::handle_thread_event),
49 ];
50
51 let mut this = Self {
52 workspace,
53 language_registry,
54 tools,
55 thread_store,
56 thread: thread.clone(),
57 messages: Vec::new(),
58 rendered_messages_by_id: HashMap::default(),
59 expanded_tool_uses: HashMap::default(),
60 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
61 let this = cx.entity().downgrade();
62 move |ix, _: &mut Window, cx: &mut App| {
63 this.update(cx, |this, cx| this.render_message(ix, cx))
64 .unwrap()
65 }
66 }),
67 last_error: None,
68 _subscriptions: subscriptions,
69 };
70
71 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
72 this.push_message(&message.id, message.text.clone(), window, cx);
73 }
74
75 this
76 }
77
78 pub fn thread(&self) -> &Entity<Thread> {
79 &self.thread
80 }
81
82 pub fn is_empty(&self) -> bool {
83 self.messages.is_empty()
84 }
85
86 pub fn summary(&self, cx: &App) -> Option<SharedString> {
87 self.thread.read(cx).summary()
88 }
89
90 pub fn summary_or_default(&self, cx: &App) -> SharedString {
91 self.thread.read(cx).summary_or_default()
92 }
93
94 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
95 self.last_error.take();
96 self.thread
97 .update(cx, |thread, _cx| thread.cancel_last_completion())
98 }
99
100 pub fn last_error(&self) -> Option<ThreadError> {
101 self.last_error.clone()
102 }
103
104 pub fn clear_last_error(&mut self) {
105 self.last_error.take();
106 }
107
108 fn push_message(
109 &mut self,
110 id: &MessageId,
111 text: String,
112 window: &mut Window,
113 cx: &mut Context<Self>,
114 ) {
115 let old_len = self.messages.len();
116 self.messages.push(*id);
117 self.list_state.splice(old_len..old_len, 1);
118
119 let theme_settings = ThemeSettings::get_global(cx);
120 let colors = cx.theme().colors();
121 let ui_font_size = TextSize::Default.rems(cx);
122 let buffer_font_size = TextSize::Small.rems(cx);
123 let mut text_style = window.text_style();
124
125 text_style.refine(&TextStyleRefinement {
126 font_family: Some(theme_settings.ui_font.family.clone()),
127 font_size: Some(ui_font_size.into()),
128 color: Some(cx.theme().colors().text),
129 ..Default::default()
130 });
131
132 let markdown_style = MarkdownStyle {
133 base_text_style: text_style,
134 syntax: cx.theme().syntax().clone(),
135 selection_background_color: cx.theme().players().local().selection,
136 code_block: StyleRefinement {
137 margin: EdgesRefinement {
138 top: Some(Length::Definite(rems(0.).into())),
139 left: Some(Length::Definite(rems(0.).into())),
140 right: Some(Length::Definite(rems(0.).into())),
141 bottom: Some(Length::Definite(rems(0.5).into())),
142 },
143 padding: EdgesRefinement {
144 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
145 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
147 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148 },
149 background: Some(colors.editor_background.into()),
150 border_color: Some(colors.border_variant),
151 border_widths: EdgesRefinement {
152 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
153 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
154 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
155 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
156 },
157 text: Some(TextStyleRefinement {
158 font_family: Some(theme_settings.buffer_font.family.clone()),
159 font_size: Some(buffer_font_size.into()),
160 ..Default::default()
161 }),
162 ..Default::default()
163 },
164 inline_code: TextStyleRefinement {
165 font_family: Some(theme_settings.buffer_font.family.clone()),
166 font_size: Some(buffer_font_size.into()),
167 background_color: Some(colors.editor_foreground.opacity(0.1)),
168 ..Default::default()
169 },
170 link: TextStyleRefinement {
171 background_color: Some(colors.editor_foreground.opacity(0.025)),
172 underline: Some(UnderlineStyle {
173 color: Some(colors.text_accent.opacity(0.5)),
174 thickness: px(1.),
175 ..Default::default()
176 }),
177 ..Default::default()
178 },
179 ..Default::default()
180 };
181
182 let markdown = cx.new(|cx| {
183 Markdown::new(
184 text.into(),
185 markdown_style,
186 Some(self.language_registry.clone()),
187 None,
188 cx,
189 )
190 });
191 self.rendered_messages_by_id.insert(*id, markdown);
192 self.list_state.scroll_to(ListOffset {
193 item_ix: old_len,
194 offset_in_item: Pixels(0.0),
195 });
196 }
197
198 fn handle_thread_event(
199 &mut self,
200 _: &Entity<Thread>,
201 event: &ThreadEvent,
202 window: &mut Window,
203 cx: &mut Context<Self>,
204 ) {
205 match event {
206 ThreadEvent::ShowError(error) => {
207 self.last_error = Some(error.clone());
208 }
209 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
210 self.thread_store
211 .update(cx, |thread_store, cx| {
212 thread_store.save_thread(&self.thread, cx)
213 })
214 .detach_and_log_err(cx);
215 }
216 ThreadEvent::StreamedAssistantText(message_id, text) => {
217 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
218 markdown.update(cx, |markdown, cx| {
219 markdown.append(text, cx);
220 });
221 }
222 }
223 ThreadEvent::MessageAdded(message_id) => {
224 if let Some(message_text) = self
225 .thread
226 .read(cx)
227 .message(*message_id)
228 .map(|message| message.text.clone())
229 {
230 self.push_message(message_id, message_text, window, cx);
231 }
232
233 self.thread_store
234 .update(cx, |thread_store, cx| {
235 thread_store.save_thread(&self.thread, cx)
236 })
237 .detach_and_log_err(cx);
238
239 cx.notify();
240 }
241 ThreadEvent::UsePendingTools => {
242 let pending_tool_uses = self
243 .thread
244 .read(cx)
245 .pending_tool_uses()
246 .into_iter()
247 .filter(|tool_use| tool_use.status.is_idle())
248 .cloned()
249 .collect::<Vec<_>>();
250
251 for tool_use in pending_tool_uses {
252 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
253 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
254
255 self.thread.update(cx, |thread, cx| {
256 thread.insert_tool_output(
257 tool_use.assistant_message_id,
258 tool_use.id.clone(),
259 task,
260 cx,
261 );
262 });
263 }
264 }
265 }
266 ThreadEvent::ToolFinished { .. } => {}
267 }
268 }
269
270 fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
271 let message_id = self.messages[ix];
272 let Some(message) = self.thread.read(cx).message(message_id) else {
273 return Empty.into_any();
274 };
275
276 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
277 return Empty.into_any();
278 };
279
280 let context = self.thread.read(cx).context_for_message(message_id);
281 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
282 let colors = cx.theme().colors();
283
284 let message_content = v_flex()
285 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
286 .when_some(context, |parent, context| {
287 if !context.is_empty() {
288 parent.child(
289 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
290 context
291 .into_iter()
292 .map(|context| ContextPill::added(context, false, false, None)),
293 ),
294 )
295 } else {
296 parent
297 }
298 });
299
300 let styled_message = match message.role {
301 Role::User => v_flex()
302 .id(("message-container", ix))
303 .pt_2p5()
304 .px_2p5()
305 .child(
306 v_flex()
307 .bg(colors.editor_background)
308 .rounded_lg()
309 .border_1()
310 .border_color(colors.border)
311 .shadow_sm()
312 .child(
313 h_flex()
314 .py_1()
315 .px_2()
316 .bg(colors.editor_foreground.opacity(0.05))
317 .border_b_1()
318 .border_color(colors.border)
319 .justify_between()
320 .rounded_t(px(6.))
321 .child(
322 h_flex()
323 .gap_1p5()
324 .child(
325 Icon::new(IconName::PersonCircle)
326 .size(IconSize::XSmall)
327 .color(Color::Muted),
328 )
329 .child(
330 Label::new("You")
331 .size(LabelSize::Small)
332 .color(Color::Muted),
333 ),
334 ),
335 )
336 .child(message_content),
337 ),
338 Role::Assistant => div()
339 .id(("message-container", ix))
340 .child(message_content)
341 .map(|parent| {
342 if tool_uses.is_empty() {
343 return parent;
344 }
345
346 parent.child(
347 v_flex().children(
348 tool_uses
349 .into_iter()
350 .map(|tool_use| self.render_tool_use(tool_use, cx)),
351 ),
352 )
353 }),
354 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
355 v_flex()
356 .bg(colors.editor_background)
357 .rounded_md()
358 .child(message_content),
359 ),
360 };
361
362 styled_message.into_any()
363 }
364
365 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
366 let is_open = self
367 .expanded_tool_uses
368 .get(&tool_use.id)
369 .copied()
370 .unwrap_or_default();
371
372 v_flex().px_2p5().child(
373 v_flex()
374 .gap_1()
375 .bg(cx.theme().colors().editor_background)
376 .rounded_lg()
377 .border_1()
378 .border_color(cx.theme().colors().border)
379 .shadow_sm()
380 .child(
381 h_flex()
382 .justify_between()
383 .py_1()
384 .px_2()
385 .bg(cx.theme().colors().editor_foreground.opacity(0.05))
386 .when(is_open, |element| element.border_b_1())
387 .border_color(cx.theme().colors().border)
388 .rounded_t(px(6.))
389 .child(
390 h_flex()
391 .gap_2()
392 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
393 cx.listener({
394 let tool_use_id = tool_use.id.clone();
395 move |this, _event, _window, _cx| {
396 let is_open = this
397 .expanded_tool_uses
398 .entry(tool_use_id.clone())
399 .or_insert(false);
400
401 *is_open = !*is_open;
402 }
403 }),
404 ))
405 .child(Label::new(tool_use.name)),
406 )
407 .child(Label::new(match tool_use.status {
408 ToolUseStatus::Pending => "Pending",
409 ToolUseStatus::Running => "Running",
410 ToolUseStatus::Finished(_) => "Finished",
411 ToolUseStatus::Error(_) => "Error",
412 })),
413 )
414 .map(|parent| {
415 if !is_open {
416 return parent;
417 }
418
419 parent.child(
420 v_flex()
421 .gap_2()
422 .p_2p5()
423 .child(
424 v_flex()
425 .gap_0p5()
426 .child(Label::new("Input:"))
427 .child(Label::new(
428 serde_json::to_string_pretty(&tool_use.input)
429 .unwrap_or_default(),
430 )),
431 )
432 .map(|parent| match tool_use.status {
433 ToolUseStatus::Finished(output) => parent.child(
434 v_flex()
435 .gap_0p5()
436 .child(Label::new("Result:"))
437 .child(Label::new(output)),
438 ),
439 ToolUseStatus::Error(err) => parent.child(
440 v_flex()
441 .gap_0p5()
442 .child(Label::new("Error:"))
443 .child(Label::new(err)),
444 ),
445 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
446 }),
447 )
448 }),
449 )
450 }
451}
452
453impl Render for ActiveThread {
454 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
455 v_flex()
456 .size_full()
457 .child(list(self.list_state.clone()).flex_grow())
458 }
459}