1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
7 ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
8 UnderlineStyle, WeakEntity,
9};
10use language::LanguageRegistry;
11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::{prelude::*, Disclosure};
16use workspace::Workspace;
17
18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
19use crate::thread_store::ThreadStore;
20use crate::tool_use::{ToolUse, ToolUseStatus};
21use crate::ui::ContextPill;
22
/// View state for the assistant thread that is currently open: the rendered
/// message list plus the bookkeeping needed to stream text into it, persist
/// the thread, and run the tools the model requests.
pub struct ActiveThread {
    /// Workspace handed to each tool when it runs.
    workspace: WeakEntity<Workspace>,
    /// Passed to every message's `Markdown` entity (presumably for code-block
    /// language lookup — confirm against `Markdown::new`).
    language_registry: Arc<LanguageRegistry>,
    /// Tools that pending tool uses are resolved against, by name.
    tools: Arc<ToolWorkingSet>,
    /// Store used to persist `thread` after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Ids of the displayed messages in list order; entry `ix` is list item `ix`.
    messages: Vec<MessageId>,
    /// State for the `list` element that renders `messages`.
    list_state: ListState,
    /// Markdown entity rendering each message's text, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// Whether each tool use's disclosure is expanded; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported by the thread, until it is cleared.
    last_error: Option<ThreadError>,
    /// Keeps the thread observation and event subscription alive.
    _subscriptions: Vec<Subscription>,
}
36
37impl ActiveThread {
38 pub fn new(
39 thread: Entity<Thread>,
40 thread_store: Entity<ThreadStore>,
41 workspace: WeakEntity<Workspace>,
42 language_registry: Arc<LanguageRegistry>,
43 tools: Arc<ToolWorkingSet>,
44 window: &mut Window,
45 cx: &mut Context<Self>,
46 ) -> Self {
47 let subscriptions = vec![
48 cx.observe(&thread, |_, _, cx| cx.notify()),
49 cx.subscribe_in(&thread, window, Self::handle_thread_event),
50 ];
51
52 let mut this = Self {
53 workspace,
54 language_registry,
55 tools,
56 thread_store,
57 thread: thread.clone(),
58 messages: Vec::new(),
59 rendered_messages_by_id: HashMap::default(),
60 expanded_tool_uses: HashMap::default(),
61 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
62 let this = cx.entity().downgrade();
63 move |ix, _: &mut Window, cx: &mut App| {
64 this.update(cx, |this, cx| this.render_message(ix, cx))
65 .unwrap()
66 }
67 }),
68 last_error: None,
69 _subscriptions: subscriptions,
70 };
71
72 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
73 this.push_message(&message.id, message.text.clone(), window, cx);
74 }
75
76 this
77 }
78
79 pub fn thread(&self) -> &Entity<Thread> {
80 &self.thread
81 }
82
83 pub fn is_empty(&self) -> bool {
84 self.messages.is_empty()
85 }
86
87 pub fn summary(&self, cx: &App) -> Option<SharedString> {
88 self.thread.read(cx).summary()
89 }
90
91 pub fn summary_or_default(&self, cx: &App) -> SharedString {
92 self.thread.read(cx).summary_or_default()
93 }
94
95 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
96 self.last_error.take();
97 self.thread
98 .update(cx, |thread, _cx| thread.cancel_last_completion())
99 }
100
101 pub fn last_error(&self) -> Option<ThreadError> {
102 self.last_error.clone()
103 }
104
105 pub fn clear_last_error(&mut self) {
106 self.last_error.take();
107 }
108
109 fn push_message(
110 &mut self,
111 id: &MessageId,
112 text: String,
113 window: &mut Window,
114 cx: &mut Context<Self>,
115 ) {
116 let old_len = self.messages.len();
117 self.messages.push(*id);
118 self.list_state.splice(old_len..old_len, 1);
119
120 let theme_settings = ThemeSettings::get_global(cx);
121 let colors = cx.theme().colors();
122 let ui_font_size = TextSize::Default.rems(cx);
123 let buffer_font_size = TextSize::Small.rems(cx);
124 let mut text_style = window.text_style();
125
126 text_style.refine(&TextStyleRefinement {
127 font_family: Some(theme_settings.ui_font.family.clone()),
128 font_size: Some(ui_font_size.into()),
129 color: Some(cx.theme().colors().text),
130 ..Default::default()
131 });
132
133 let markdown_style = MarkdownStyle {
134 base_text_style: text_style,
135 syntax: cx.theme().syntax().clone(),
136 selection_background_color: cx.theme().players().local().selection,
137 code_block_overflow_x_scroll: true,
138 table_overflow_x_scroll: true,
139 code_block: StyleRefinement {
140 margin: EdgesRefinement {
141 top: Some(Length::Definite(rems(0.).into())),
142 left: Some(Length::Definite(rems(0.).into())),
143 right: Some(Length::Definite(rems(0.).into())),
144 bottom: Some(Length::Definite(rems(0.5).into())),
145 },
146 padding: EdgesRefinement {
147 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
149 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
150 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
151 },
152 background: Some(colors.editor_background.into()),
153 border_color: Some(colors.border_variant),
154 border_widths: EdgesRefinement {
155 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
156 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
157 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
158 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
159 },
160 text: Some(TextStyleRefinement {
161 font_family: Some(theme_settings.buffer_font.family.clone()),
162 font_size: Some(buffer_font_size.into()),
163 ..Default::default()
164 }),
165 ..Default::default()
166 },
167 inline_code: TextStyleRefinement {
168 font_family: Some(theme_settings.buffer_font.family.clone()),
169 font_size: Some(buffer_font_size.into()),
170 background_color: Some(colors.editor_foreground.opacity(0.1)),
171 ..Default::default()
172 },
173 link: TextStyleRefinement {
174 background_color: Some(colors.editor_foreground.opacity(0.025)),
175 underline: Some(UnderlineStyle {
176 color: Some(colors.text_accent.opacity(0.5)),
177 thickness: px(1.),
178 ..Default::default()
179 }),
180 ..Default::default()
181 },
182 ..Default::default()
183 };
184
185 let markdown = cx.new(|cx| {
186 Markdown::new(
187 text.into(),
188 markdown_style,
189 Some(self.language_registry.clone()),
190 None,
191 cx,
192 )
193 });
194 self.rendered_messages_by_id.insert(*id, markdown);
195 self.list_state.scroll_to(ListOffset {
196 item_ix: old_len,
197 offset_in_item: Pixels(0.0),
198 });
199 }
200
201 fn handle_thread_event(
202 &mut self,
203 _: &Entity<Thread>,
204 event: &ThreadEvent,
205 window: &mut Window,
206 cx: &mut Context<Self>,
207 ) {
208 match event {
209 ThreadEvent::ShowError(error) => {
210 self.last_error = Some(error.clone());
211 }
212 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
213 self.thread_store
214 .update(cx, |thread_store, cx| {
215 thread_store.save_thread(&self.thread, cx)
216 })
217 .detach_and_log_err(cx);
218 }
219 ThreadEvent::StreamedAssistantText(message_id, text) => {
220 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
221 markdown.update(cx, |markdown, cx| {
222 markdown.append(text, cx);
223 });
224 }
225 }
226 ThreadEvent::MessageAdded(message_id) => {
227 if let Some(message_text) = self
228 .thread
229 .read(cx)
230 .message(*message_id)
231 .map(|message| message.text.clone())
232 {
233 self.push_message(message_id, message_text, window, cx);
234 }
235
236 self.thread_store
237 .update(cx, |thread_store, cx| {
238 thread_store.save_thread(&self.thread, cx)
239 })
240 .detach_and_log_err(cx);
241
242 cx.notify();
243 }
244 ThreadEvent::UsePendingTools => {
245 let pending_tool_uses = self
246 .thread
247 .read(cx)
248 .pending_tool_uses()
249 .into_iter()
250 .filter(|tool_use| tool_use.status.is_idle())
251 .cloned()
252 .collect::<Vec<_>>();
253
254 for tool_use in pending_tool_uses {
255 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
256 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
257
258 self.thread.update(cx, |thread, cx| {
259 thread.insert_tool_output(tool_use.id.clone(), task, cx);
260 });
261 }
262 }
263 }
264 ThreadEvent::ToolFinished { .. } => {
265 let all_tools_finished = self
266 .thread
267 .read(cx)
268 .pending_tool_uses()
269 .into_iter()
270 .all(|tool_use| tool_use.status.is_error());
271 if all_tools_finished {
272 let model_registry = LanguageModelRegistry::read_global(cx);
273 if let Some(model) = model_registry.active_model() {
274 self.thread.update(cx, |thread, cx| {
275 // Insert a user message to contain the tool results.
276 thread.insert_user_message(
277 // TODO: Sending up a user message without any content results in the model sending back
278 // responses that also don't have any content. We currently don't handle this case well,
279 // so for now we provide some text to keep the model on track.
280 "Here are the tool results.",
281 Vec::new(),
282 cx,
283 );
284 thread.send_to_model(model, RequestKind::Chat, true, cx);
285 });
286 }
287 }
288 }
289 }
290 }
291
292 fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
293 let message_id = self.messages[ix];
294 let Some(message) = self.thread.read(cx).message(message_id) else {
295 return Empty.into_any();
296 };
297
298 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
299 return Empty.into_any();
300 };
301
302 let context = self.thread.read(cx).context_for_message(message_id);
303 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
304 let colors = cx.theme().colors();
305
306 // Don't render user messages that are just there for returning tool results.
307 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
308 return Empty.into_any();
309 }
310
311 let message_content = v_flex()
312 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
313 .when_some(context, |parent, context| {
314 if !context.is_empty() {
315 parent.child(
316 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
317 context
318 .into_iter()
319 .map(|context| ContextPill::added(context, false, false, None)),
320 ),
321 )
322 } else {
323 parent
324 }
325 });
326
327 let styled_message = match message.role {
328 Role::User => v_flex()
329 .id(("message-container", ix))
330 .pt_2p5()
331 .px_2p5()
332 .child(
333 v_flex()
334 .bg(colors.editor_background)
335 .rounded_lg()
336 .border_1()
337 .border_color(colors.border)
338 .shadow_sm()
339 .child(
340 h_flex()
341 .py_1()
342 .px_2()
343 .bg(colors.editor_foreground.opacity(0.05))
344 .border_b_1()
345 .border_color(colors.border)
346 .justify_between()
347 .rounded_t(px(6.))
348 .child(
349 h_flex()
350 .gap_1p5()
351 .child(
352 Icon::new(IconName::PersonCircle)
353 .size(IconSize::XSmall)
354 .color(Color::Muted),
355 )
356 .child(
357 Label::new("You")
358 .size(LabelSize::Small)
359 .color(Color::Muted),
360 ),
361 ),
362 )
363 .child(message_content),
364 ),
365 Role::Assistant => div()
366 .id(("message-container", ix))
367 .child(message_content)
368 .map(|parent| {
369 if tool_uses.is_empty() {
370 return parent;
371 }
372
373 parent.child(
374 v_flex().children(
375 tool_uses
376 .into_iter()
377 .map(|tool_use| self.render_tool_use(tool_use, cx)),
378 ),
379 )
380 }),
381 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
382 v_flex()
383 .bg(colors.editor_background)
384 .rounded_md()
385 .child(message_content),
386 ),
387 };
388
389 styled_message.into_any()
390 }
391
392 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
393 let is_open = self
394 .expanded_tool_uses
395 .get(&tool_use.id)
396 .copied()
397 .unwrap_or_default();
398
399 div().px_2p5().child(
400 v_flex()
401 .gap_1()
402 .rounded_lg()
403 .border_1()
404 .border_color(cx.theme().colors().border)
405 .child(
406 h_flex()
407 .justify_between()
408 .py_0p5()
409 .pl_1()
410 .pr_2()
411 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
412 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
413 .when(!is_open, |element| element.rounded(px(6.)))
414 .border_color(cx.theme().colors().border)
415 .child(
416 h_flex()
417 .gap_1()
418 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
419 cx.listener({
420 let tool_use_id = tool_use.id.clone();
421 move |this, _event, _window, _cx| {
422 let is_open = this
423 .expanded_tool_uses
424 .entry(tool_use_id.clone())
425 .or_insert(false);
426
427 *is_open = !*is_open;
428 }
429 }),
430 ))
431 .child(Label::new(tool_use.name)),
432 )
433 .child(
434 Label::new(match tool_use.status {
435 ToolUseStatus::Pending => "Pending",
436 ToolUseStatus::Running => "Running",
437 ToolUseStatus::Finished(_) => "Finished",
438 ToolUseStatus::Error(_) => "Error",
439 })
440 .size(LabelSize::XSmall)
441 .buffer_font(cx),
442 ),
443 )
444 .map(|parent| {
445 if !is_open {
446 return parent;
447 }
448
449 parent.child(
450 v_flex()
451 .child(
452 v_flex()
453 .gap_0p5()
454 .py_1()
455 .px_2p5()
456 .border_b_1()
457 .border_color(cx.theme().colors().border)
458 .child(Label::new("Input:"))
459 .child(Label::new(
460 serde_json::to_string_pretty(&tool_use.input)
461 .unwrap_or_default(),
462 )),
463 )
464 .map(|parent| match tool_use.status {
465 ToolUseStatus::Finished(output) => parent.child(
466 v_flex()
467 .gap_0p5()
468 .py_1()
469 .px_2p5()
470 .child(Label::new("Result:"))
471 .child(Label::new(output)),
472 ),
473 ToolUseStatus::Error(err) => parent.child(
474 v_flex()
475 .gap_0p5()
476 .py_1()
477 .px_2p5()
478 .child(Label::new("Error:"))
479 .child(Label::new(err)),
480 ),
481 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
482 }),
483 )
484 }),
485 )
486 }
487}
488
489impl Render for ActiveThread {
490 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
491 v_flex()
492 .size_full()
493 .child(list(self.list_state.clone()).flex_grow())
494 }
495}