1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
7 ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
8 UnderlineStyle, WeakEntity,
9};
10use language::LanguageRegistry;
11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::{prelude::*, Disclosure};
16use workspace::Workspace;
17
18use crate::thread::{
19 MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus,
20};
21use crate::thread_store::ThreadStore;
22use crate::ui::ContextPill;
23
24pub struct ActiveThread {
25 workspace: WeakEntity<Workspace>,
26 language_registry: Arc<LanguageRegistry>,
27 tools: Arc<ToolWorkingSet>,
28 thread_store: Entity<ThreadStore>,
29 thread: Entity<Thread>,
30 messages: Vec<MessageId>,
31 list_state: ListState,
32 rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
33 expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
34 last_error: Option<ThreadError>,
35 _subscriptions: Vec<Subscription>,
36}
37
38impl ActiveThread {
39 pub fn new(
40 thread: Entity<Thread>,
41 thread_store: Entity<ThreadStore>,
42 workspace: WeakEntity<Workspace>,
43 language_registry: Arc<LanguageRegistry>,
44 tools: Arc<ToolWorkingSet>,
45 window: &mut Window,
46 cx: &mut Context<Self>,
47 ) -> Self {
48 let subscriptions = vec![
49 cx.observe(&thread, |_, _, cx| cx.notify()),
50 cx.subscribe_in(&thread, window, Self::handle_thread_event),
51 ];
52
53 let mut this = Self {
54 workspace,
55 language_registry,
56 tools,
57 thread_store,
58 thread: thread.clone(),
59 messages: Vec::new(),
60 rendered_messages_by_id: HashMap::default(),
61 expanded_tool_uses: HashMap::default(),
62 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
63 let this = cx.entity().downgrade();
64 move |ix, _: &mut Window, cx: &mut App| {
65 this.update(cx, |this, cx| this.render_message(ix, cx))
66 .unwrap()
67 }
68 }),
69 last_error: None,
70 _subscriptions: subscriptions,
71 };
72
73 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
74 this.push_message(&message.id, message.text.clone(), window, cx);
75 }
76
77 this
78 }
79
80 pub fn thread(&self) -> &Entity<Thread> {
81 &self.thread
82 }
83
84 pub fn is_empty(&self) -> bool {
85 self.messages.is_empty()
86 }
87
88 pub fn summary(&self, cx: &App) -> Option<SharedString> {
89 self.thread.read(cx).summary()
90 }
91
92 pub fn summary_or_default(&self, cx: &App) -> SharedString {
93 self.thread.read(cx).summary_or_default()
94 }
95
96 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
97 self.last_error.take();
98 self.thread
99 .update(cx, |thread, _cx| thread.cancel_last_completion())
100 }
101
102 pub fn last_error(&self) -> Option<ThreadError> {
103 self.last_error.clone()
104 }
105
106 pub fn clear_last_error(&mut self) {
107 self.last_error.take();
108 }
109
110 fn push_message(
111 &mut self,
112 id: &MessageId,
113 text: String,
114 window: &mut Window,
115 cx: &mut Context<Self>,
116 ) {
117 let old_len = self.messages.len();
118 self.messages.push(*id);
119 self.list_state.splice(old_len..old_len, 1);
120
121 let theme_settings = ThemeSettings::get_global(cx);
122 let colors = cx.theme().colors();
123 let ui_font_size = TextSize::Default.rems(cx);
124 let buffer_font_size = TextSize::Small.rems(cx);
125 let mut text_style = window.text_style();
126
127 text_style.refine(&TextStyleRefinement {
128 font_family: Some(theme_settings.ui_font.family.clone()),
129 font_size: Some(ui_font_size.into()),
130 color: Some(cx.theme().colors().text),
131 ..Default::default()
132 });
133
134 let markdown_style = MarkdownStyle {
135 base_text_style: text_style,
136 syntax: cx.theme().syntax().clone(),
137 selection_background_color: cx.theme().players().local().selection,
138 code_block: StyleRefinement {
139 margin: EdgesRefinement {
140 top: Some(Length::Definite(rems(0.).into())),
141 left: Some(Length::Definite(rems(0.).into())),
142 right: Some(Length::Definite(rems(0.).into())),
143 bottom: Some(Length::Definite(rems(0.5).into())),
144 },
145 padding: EdgesRefinement {
146 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
147 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
148 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
149 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
150 },
151 background: Some(colors.editor_background.into()),
152 border_color: Some(colors.border_variant),
153 border_widths: EdgesRefinement {
154 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
155 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
156 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
157 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
158 },
159 text: Some(TextStyleRefinement {
160 font_family: Some(theme_settings.buffer_font.family.clone()),
161 font_size: Some(buffer_font_size.into()),
162 ..Default::default()
163 }),
164 ..Default::default()
165 },
166 inline_code: TextStyleRefinement {
167 font_family: Some(theme_settings.buffer_font.family.clone()),
168 font_size: Some(buffer_font_size.into()),
169 background_color: Some(colors.editor_foreground.opacity(0.1)),
170 ..Default::default()
171 },
172 link: TextStyleRefinement {
173 background_color: Some(colors.editor_foreground.opacity(0.025)),
174 underline: Some(UnderlineStyle {
175 color: Some(colors.text_accent.opacity(0.5)),
176 thickness: px(1.),
177 ..Default::default()
178 }),
179 ..Default::default()
180 },
181 ..Default::default()
182 };
183
184 let markdown = cx.new(|cx| {
185 Markdown::new(
186 text.into(),
187 markdown_style,
188 Some(self.language_registry.clone()),
189 None,
190 cx,
191 )
192 });
193 self.rendered_messages_by_id.insert(*id, markdown);
194 self.list_state.scroll_to(ListOffset {
195 item_ix: old_len,
196 offset_in_item: Pixels(0.0),
197 });
198 }
199
200 fn handle_thread_event(
201 &mut self,
202 _: &Entity<Thread>,
203 event: &ThreadEvent,
204 window: &mut Window,
205 cx: &mut Context<Self>,
206 ) {
207 match event {
208 ThreadEvent::ShowError(error) => {
209 self.last_error = Some(error.clone());
210 }
211 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
212 self.thread_store
213 .update(cx, |thread_store, cx| {
214 thread_store.save_thread(&self.thread, cx)
215 })
216 .detach_and_log_err(cx);
217 }
218 ThreadEvent::StreamedAssistantText(message_id, text) => {
219 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
220 markdown.update(cx, |markdown, cx| {
221 markdown.append(text, cx);
222 });
223 }
224 }
225 ThreadEvent::MessageAdded(message_id) => {
226 if let Some(message_text) = self
227 .thread
228 .read(cx)
229 .message(*message_id)
230 .map(|message| message.text.clone())
231 {
232 self.push_message(message_id, message_text, window, cx);
233 }
234
235 self.thread_store
236 .update(cx, |thread_store, cx| {
237 thread_store.save_thread(&self.thread, cx)
238 })
239 .detach_and_log_err(cx);
240
241 cx.notify();
242 }
243 ThreadEvent::UsePendingTools => {
244 let pending_tool_uses = self
245 .thread
246 .read(cx)
247 .pending_tool_uses()
248 .into_iter()
249 .filter(|tool_use| tool_use.status.is_idle())
250 .cloned()
251 .collect::<Vec<_>>();
252
253 for tool_use in pending_tool_uses {
254 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
255 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
256
257 self.thread.update(cx, |thread, cx| {
258 thread.insert_tool_output(tool_use.id.clone(), task, cx);
259 });
260 }
261 }
262 }
263 ThreadEvent::ToolFinished { .. } => {
264 let all_tools_finished = self
265 .thread
266 .read(cx)
267 .pending_tool_uses()
268 .into_iter()
269 .all(|tool_use| tool_use.status.is_error());
270 if all_tools_finished {
271 let model_registry = LanguageModelRegistry::read_global(cx);
272 if let Some(model) = model_registry.active_model() {
273 self.thread.update(cx, |thread, cx| {
274 // Insert an empty user message to contain the tool results.
275 thread.insert_user_message("", Vec::new(), cx);
276 thread.send_to_model(model, RequestKind::Chat, true, cx);
277 });
278 }
279 }
280 }
281 }
282 }
283
284 fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
285 let message_id = self.messages[ix];
286 let Some(message) = self.thread.read(cx).message(message_id) else {
287 return Empty.into_any();
288 };
289
290 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
291 return Empty.into_any();
292 };
293
294 let context = self.thread.read(cx).context_for_message(message_id);
295 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
296 let colors = cx.theme().colors();
297
298 // Don't render user messages that are just there for returning tool results.
299 if message.role == Role::User
300 && message.text.is_empty()
301 && self.thread.read(cx).message_has_tool_results(message_id)
302 {
303 return Empty.into_any();
304 }
305
306 let message_content = v_flex()
307 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
308 .when_some(context, |parent, context| {
309 if !context.is_empty() {
310 parent.child(
311 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
312 context
313 .into_iter()
314 .map(|context| ContextPill::added(context, false, false, None)),
315 ),
316 )
317 } else {
318 parent
319 }
320 });
321
322 let styled_message = match message.role {
323 Role::User => v_flex()
324 .id(("message-container", ix))
325 .pt_2p5()
326 .px_2p5()
327 .child(
328 v_flex()
329 .bg(colors.editor_background)
330 .rounded_lg()
331 .border_1()
332 .border_color(colors.border)
333 .shadow_sm()
334 .child(
335 h_flex()
336 .py_1()
337 .px_2()
338 .bg(colors.editor_foreground.opacity(0.05))
339 .border_b_1()
340 .border_color(colors.border)
341 .justify_between()
342 .rounded_t(px(6.))
343 .child(
344 h_flex()
345 .gap_1p5()
346 .child(
347 Icon::new(IconName::PersonCircle)
348 .size(IconSize::XSmall)
349 .color(Color::Muted),
350 )
351 .child(
352 Label::new("You")
353 .size(LabelSize::Small)
354 .color(Color::Muted),
355 ),
356 ),
357 )
358 .child(message_content),
359 ),
360 Role::Assistant => div()
361 .id(("message-container", ix))
362 .child(message_content)
363 .map(|parent| {
364 if tool_uses.is_empty() {
365 return parent;
366 }
367
368 parent.child(
369 v_flex().children(
370 tool_uses
371 .into_iter()
372 .map(|tool_use| self.render_tool_use(tool_use, cx)),
373 ),
374 )
375 }),
376 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
377 v_flex()
378 .bg(colors.editor_background)
379 .rounded_md()
380 .child(message_content),
381 ),
382 };
383
384 styled_message.into_any()
385 }
386
387 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
388 let is_open = self
389 .expanded_tool_uses
390 .get(&tool_use.id)
391 .copied()
392 .unwrap_or_default();
393
394 div().px_2p5().child(
395 v_flex()
396 .gap_1()
397 .rounded_lg()
398 .border_1()
399 .border_color(cx.theme().colors().border)
400 .child(
401 h_flex()
402 .justify_between()
403 .py_0p5()
404 .pl_1()
405 .pr_2()
406 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
407 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
408 .when(!is_open, |element| element.rounded(px(6.)))
409 .border_color(cx.theme().colors().border)
410 .child(
411 h_flex()
412 .gap_1()
413 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
414 cx.listener({
415 let tool_use_id = tool_use.id.clone();
416 move |this, _event, _window, _cx| {
417 let is_open = this
418 .expanded_tool_uses
419 .entry(tool_use_id.clone())
420 .or_insert(false);
421
422 *is_open = !*is_open;
423 }
424 }),
425 ))
426 .child(Label::new(tool_use.name)),
427 )
428 .child(
429 Label::new(match tool_use.status {
430 ToolUseStatus::Pending => "Pending",
431 ToolUseStatus::Running => "Running",
432 ToolUseStatus::Finished(_) => "Finished",
433 ToolUseStatus::Error(_) => "Error",
434 })
435 .size(LabelSize::XSmall)
436 .buffer_font(cx),
437 ),
438 )
439 .map(|parent| {
440 if !is_open {
441 return parent;
442 }
443
444 parent.child(
445 v_flex()
446 .child(
447 v_flex()
448 .gap_0p5()
449 .py_1()
450 .px_2p5()
451 .border_b_1()
452 .border_color(cx.theme().colors().border)
453 .child(Label::new("Input:"))
454 .child(Label::new(
455 serde_json::to_string_pretty(&tool_use.input)
456 .unwrap_or_default(),
457 )),
458 )
459 .map(|parent| match tool_use.status {
460 ToolUseStatus::Finished(output) => parent.child(
461 v_flex()
462 .gap_0p5()
463 .py_1()
464 .px_2p5()
465 .child(Label::new("Result:"))
466 .child(Label::new(output)),
467 ),
468 ToolUseStatus::Error(err) => parent.child(
469 v_flex()
470 .gap_0p5()
471 .py_1()
472 .px_2p5()
473 .child(Label::new("Error:"))
474 .child(Label::new(err)),
475 ),
476 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
477 }),
478 )
479 }),
480 )
481 }
482}
483
484impl Render for ActiveThread {
485 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
486 v_flex()
487 .size_full()
488 .child(list(self.list_state.clone()).flex_grow())
489 }
490}