1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use editor::{Editor, MultiBuffer};
6use gpui::{
7 list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
8 Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
9 TextStyleRefinement, UnderlineStyle, WeakEntity,
10};
11use language::{Buffer, LanguageRegistry};
12use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
13use markdown::{Markdown, MarkdownStyle};
14use settings::Settings as _;
15use theme::ThemeSettings;
16use ui::{prelude::*, Disclosure, KeyBinding};
17use workspace::Workspace;
18
19use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
20use crate::thread_store::ThreadStore;
21use crate::tool_use::{ToolUse, ToolUseStatus};
22use crate::ui::ContextPill;
23
/// A view over a single [`Thread`].
///
/// It keeps a list of the thread's messages in sync with [`ThreadEvent`]s and renders each
/// message along with its tool uses. It also lets the user edit and regenerate their most
/// recent message.
pub struct ActiveThread {
    /// Handed to tools when they are run for pending tool uses.
    workspace: WeakEntity<Workspace>,
    /// Passed to every markdown entity this view creates.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks to use pending tools.
    tools: Arc<ToolWorkingSet>,
    /// Used to persist the thread after it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Message ids in display order; entry `i` corresponds to item `i` of `list_state`.
    messages: Vec<MessageId>,
    /// List state whose items are produced on demand by `render_message`.
    list_state: ListState,
    /// Markdown entity holding the rendered text of each message.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited (if any) together with its editor.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Expanded/collapsed state per tool use; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported by the thread, kept until cleared.
    last_error: Option<ThreadError>,
    /// Keeps the thread observation and event subscription alive.
    _subscriptions: Vec<Subscription>,
}
38
/// State for an in-progress edit of a user message.
struct EditMessageState {
    /// Editor holding the draft; its text replaces the message when the edit is confirmed.
    editor: Entity<Editor>,
}
42
43impl ActiveThread {
44 pub fn new(
45 thread: Entity<Thread>,
46 thread_store: Entity<ThreadStore>,
47 workspace: WeakEntity<Workspace>,
48 language_registry: Arc<LanguageRegistry>,
49 tools: Arc<ToolWorkingSet>,
50 window: &mut Window,
51 cx: &mut Context<Self>,
52 ) -> Self {
53 let subscriptions = vec![
54 cx.observe(&thread, |_, _, cx| cx.notify()),
55 cx.subscribe_in(&thread, window, Self::handle_thread_event),
56 ];
57
58 let mut this = Self {
59 workspace,
60 language_registry,
61 tools,
62 thread_store,
63 thread: thread.clone(),
64 messages: Vec::new(),
65 rendered_messages_by_id: HashMap::default(),
66 expanded_tool_uses: HashMap::default(),
67 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
68 let this = cx.entity().downgrade();
69 move |ix, window: &mut Window, cx: &mut App| {
70 this.update(cx, |this, cx| this.render_message(ix, window, cx))
71 .unwrap()
72 }
73 }),
74 editing_message: None,
75 last_error: None,
76 _subscriptions: subscriptions,
77 };
78
79 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
80 this.push_message(&message.id, message.text.clone(), window, cx);
81 }
82
83 this
84 }
85
86 pub fn thread(&self) -> &Entity<Thread> {
87 &self.thread
88 }
89
90 pub fn is_empty(&self) -> bool {
91 self.messages.is_empty()
92 }
93
94 pub fn summary(&self, cx: &App) -> Option<SharedString> {
95 self.thread.read(cx).summary()
96 }
97
98 pub fn summary_or_default(&self, cx: &App) -> SharedString {
99 self.thread.read(cx).summary_or_default()
100 }
101
102 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
103 self.last_error.take();
104 self.thread
105 .update(cx, |thread, _cx| thread.cancel_last_completion())
106 }
107
108 pub fn last_error(&self) -> Option<ThreadError> {
109 self.last_error.clone()
110 }
111
112 pub fn clear_last_error(&mut self) {
113 self.last_error.take();
114 }
115
116 fn push_message(
117 &mut self,
118 id: &MessageId,
119 text: String,
120 window: &mut Window,
121 cx: &mut Context<Self>,
122 ) {
123 let old_len = self.messages.len();
124 self.messages.push(*id);
125 self.list_state.splice(old_len..old_len, 1);
126
127 let markdown = self.render_markdown(text.into(), window, cx);
128 self.rendered_messages_by_id.insert(*id, markdown);
129 self.list_state.scroll_to(ListOffset {
130 item_ix: old_len,
131 offset_in_item: Pixels(0.0),
132 });
133 }
134
135 fn edited_message(
136 &mut self,
137 id: &MessageId,
138 text: String,
139 window: &mut Window,
140 cx: &mut Context<Self>,
141 ) {
142 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
143 return;
144 };
145 self.list_state.splice(index..index + 1, 1);
146 let markdown = self.render_markdown(text.into(), window, cx);
147 self.rendered_messages_by_id.insert(*id, markdown);
148 }
149
150 fn deleted_message(&mut self, id: &MessageId) {
151 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
152 return;
153 };
154 self.messages.remove(index);
155 self.list_state.splice(index..index + 1, 0);
156 self.rendered_messages_by_id.remove(id);
157 }
158
159 fn render_markdown(
160 &self,
161 text: SharedString,
162 window: &Window,
163 cx: &mut Context<Self>,
164 ) -> Entity<Markdown> {
165 let theme_settings = ThemeSettings::get_global(cx);
166 let colors = cx.theme().colors();
167 let ui_font_size = TextSize::Default.rems(cx);
168 let buffer_font_size = TextSize::Small.rems(cx);
169 let mut text_style = window.text_style();
170
171 text_style.refine(&TextStyleRefinement {
172 font_family: Some(theme_settings.ui_font.family.clone()),
173 font_size: Some(ui_font_size.into()),
174 color: Some(cx.theme().colors().text),
175 ..Default::default()
176 });
177
178 let markdown_style = MarkdownStyle {
179 base_text_style: text_style,
180 syntax: cx.theme().syntax().clone(),
181 selection_background_color: cx.theme().players().local().selection,
182 code_block_overflow_x_scroll: true,
183 table_overflow_x_scroll: true,
184 code_block: StyleRefinement {
185 margin: EdgesRefinement {
186 top: Some(Length::Definite(rems(0.).into())),
187 left: Some(Length::Definite(rems(0.).into())),
188 right: Some(Length::Definite(rems(0.).into())),
189 bottom: Some(Length::Definite(rems(0.5).into())),
190 },
191 padding: EdgesRefinement {
192 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
194 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
195 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
196 },
197 background: Some(colors.editor_background.into()),
198 border_color: Some(colors.border_variant),
199 border_widths: EdgesRefinement {
200 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
201 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
202 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
203 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
204 },
205 text: Some(TextStyleRefinement {
206 font_family: Some(theme_settings.buffer_font.family.clone()),
207 font_size: Some(buffer_font_size.into()),
208 ..Default::default()
209 }),
210 ..Default::default()
211 },
212 inline_code: TextStyleRefinement {
213 font_family: Some(theme_settings.buffer_font.family.clone()),
214 font_size: Some(buffer_font_size.into()),
215 background_color: Some(colors.editor_foreground.opacity(0.1)),
216 ..Default::default()
217 },
218 link: TextStyleRefinement {
219 background_color: Some(colors.editor_foreground.opacity(0.025)),
220 underline: Some(UnderlineStyle {
221 color: Some(colors.text_accent.opacity(0.5)),
222 thickness: px(1.),
223 ..Default::default()
224 }),
225 ..Default::default()
226 },
227 ..Default::default()
228 };
229
230 cx.new(|cx| {
231 Markdown::new(
232 text,
233 markdown_style,
234 Some(self.language_registry.clone()),
235 None,
236 cx,
237 )
238 })
239 }
240
241 fn handle_thread_event(
242 &mut self,
243 _: &Entity<Thread>,
244 event: &ThreadEvent,
245 window: &mut Window,
246 cx: &mut Context<Self>,
247 ) {
248 match event {
249 ThreadEvent::ShowError(error) => {
250 self.last_error = Some(error.clone());
251 }
252 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
253 self.thread_store
254 .update(cx, |thread_store, cx| {
255 thread_store.save_thread(&self.thread, cx)
256 })
257 .detach_and_log_err(cx);
258 }
259 ThreadEvent::StreamedAssistantText(message_id, text) => {
260 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
261 markdown.update(cx, |markdown, cx| {
262 markdown.append(text, cx);
263 });
264 }
265 }
266 ThreadEvent::MessageAdded(message_id) => {
267 if let Some(message_text) = self
268 .thread
269 .read(cx)
270 .message(*message_id)
271 .map(|message| message.text.clone())
272 {
273 self.push_message(message_id, message_text, window, cx);
274 }
275
276 self.thread_store
277 .update(cx, |thread_store, cx| {
278 thread_store.save_thread(&self.thread, cx)
279 })
280 .detach_and_log_err(cx);
281
282 cx.notify();
283 }
284 ThreadEvent::MessageEdited(message_id) => {
285 if let Some(message_text) = self
286 .thread
287 .read(cx)
288 .message(*message_id)
289 .map(|message| message.text.clone())
290 {
291 self.edited_message(message_id, message_text, window, cx);
292 }
293
294 self.thread_store
295 .update(cx, |thread_store, cx| {
296 thread_store.save_thread(&self.thread, cx)
297 })
298 .detach_and_log_err(cx);
299
300 cx.notify();
301 }
302 ThreadEvent::MessageDeleted(message_id) => {
303 self.deleted_message(message_id);
304
305 self.thread_store
306 .update(cx, |thread_store, cx| {
307 thread_store.save_thread(&self.thread, cx)
308 })
309 .detach_and_log_err(cx);
310
311 cx.notify();
312 }
313 ThreadEvent::UsePendingTools => {
314 let pending_tool_uses = self
315 .thread
316 .read(cx)
317 .pending_tool_uses()
318 .into_iter()
319 .filter(|tool_use| tool_use.status.is_idle())
320 .cloned()
321 .collect::<Vec<_>>();
322
323 for tool_use in pending_tool_uses {
324 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
325 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
326
327 self.thread.update(cx, |thread, cx| {
328 thread.insert_tool_output(tool_use.id.clone(), task, cx);
329 });
330 }
331 }
332 }
333 ThreadEvent::ToolFinished { .. } => {
334 let all_tools_finished = self
335 .thread
336 .read(cx)
337 .pending_tool_uses()
338 .into_iter()
339 .all(|tool_use| tool_use.status.is_error());
340 if all_tools_finished {
341 let model_registry = LanguageModelRegistry::read_global(cx);
342 if let Some(model) = model_registry.active_model() {
343 self.thread.update(cx, |thread, cx| {
344 // Insert a user message to contain the tool results.
345 thread.insert_user_message(
346 // TODO: Sending up a user message without any content results in the model sending back
347 // responses that also don't have any content. We currently don't handle this case well,
348 // so for now we provide some text to keep the model on track.
349 "Here are the tool results.",
350 Vec::new(),
351 cx,
352 );
353 thread.send_to_model(model, RequestKind::Chat, true, cx);
354 });
355 }
356 }
357 }
358 }
359 }
360
361 fn start_editing_message(
362 &mut self,
363 message_id: MessageId,
364 message_text: String,
365 window: &mut Window,
366 cx: &mut Context<Self>,
367 ) {
368 let buffer = cx.new(|cx| {
369 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
370 });
371 let editor = cx.new(|cx| {
372 let mut editor = Editor::new(
373 editor::EditorMode::AutoHeight { max_lines: 8 },
374 buffer,
375 None,
376 false,
377 window,
378 cx,
379 );
380 editor.focus_handle(cx).focus(window);
381 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
382 editor
383 });
384 self.editing_message = Some((
385 message_id,
386 EditMessageState {
387 editor: editor.clone(),
388 },
389 ));
390 cx.notify();
391 }
392
393 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
394 self.editing_message.take();
395 cx.notify();
396 }
397
398 fn confirm_editing_message(
399 &mut self,
400 _: &menu::Confirm,
401 _: &mut Window,
402 cx: &mut Context<Self>,
403 ) {
404 let Some((message_id, state)) = self.editing_message.take() else {
405 return;
406 };
407 let edited_text = state.editor.read(cx).text(cx);
408 self.thread.update(cx, |thread, cx| {
409 thread.edit_message(message_id, Role::User, edited_text, cx);
410 for message_id in self.messages_after(message_id) {
411 thread.delete_message(*message_id, cx);
412 }
413 });
414
415 let provider = LanguageModelRegistry::read_global(cx).active_provider();
416 if provider
417 .as_ref()
418 .map_or(false, |provider| provider.must_accept_terms(cx))
419 {
420 cx.notify();
421 return;
422 }
423 let model_registry = LanguageModelRegistry::read_global(cx);
424 let Some(model) = model_registry.active_model() else {
425 return;
426 };
427
428 self.thread.update(cx, |thread, cx| {
429 thread.send_to_model(model, RequestKind::Chat, false, cx)
430 });
431 cx.notify();
432 }
433
434 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
435 self.messages
436 .iter()
437 .rev()
438 .find(|message_id| {
439 self.thread
440 .read(cx)
441 .message(**message_id)
442 .map_or(false, |message| message.role == Role::User)
443 })
444 .cloned()
445 }
446
447 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
448 self.messages
449 .iter()
450 .position(|id| *id == message_id)
451 .map(|index| &self.messages[index + 1..])
452 .unwrap_or(&[])
453 }
454
455 fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
456 self.cancel_editing_message(&menu::Cancel, window, cx);
457 }
458
459 fn handle_regenerate_click(
460 &mut self,
461 _: &ClickEvent,
462 window: &mut Window,
463 cx: &mut Context<Self>,
464 ) {
465 self.confirm_editing_message(&menu::Confirm, window, cx);
466 }
467
468 fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
469 let message_id = self.messages[ix];
470 let Some(message) = self.thread.read(cx).message(message_id) else {
471 return Empty.into_any();
472 };
473
474 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
475 return Empty.into_any();
476 };
477
478 let context = self.thread.read(cx).context_for_message(message_id);
479 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
480 let colors = cx.theme().colors();
481
482 // Don't render user messages that are just there for returning tool results.
483 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
484 return Empty.into_any();
485 }
486
487 let allow_editing_message =
488 message.role == Role::User && self.last_user_message(cx) == Some(message_id);
489
490 let edit_message_editor = self
491 .editing_message
492 .as_ref()
493 .filter(|(id, _)| *id == message_id)
494 .map(|(_, state)| state.editor.clone());
495
496 let message_content = v_flex()
497 .child(
498 if let Some(edit_message_editor) = edit_message_editor.clone() {
499 div()
500 .key_context("EditMessageEditor")
501 .on_action(cx.listener(Self::cancel_editing_message))
502 .on_action(cx.listener(Self::confirm_editing_message))
503 .p_2p5()
504 .child(edit_message_editor)
505 } else {
506 div().p_2p5().text_ui(cx).child(markdown.clone())
507 },
508 )
509 .when_some(context, |parent, context| {
510 if !context.is_empty() {
511 parent.child(
512 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
513 context
514 .into_iter()
515 .map(|context| ContextPill::added(context, false, false, None)),
516 ),
517 )
518 } else {
519 parent
520 }
521 });
522
523 let styled_message = match message.role {
524 Role::User => v_flex()
525 .id(("message-container", ix))
526 .pt_2p5()
527 .px_2p5()
528 .child(
529 v_flex()
530 .bg(colors.editor_background)
531 .rounded_lg()
532 .border_1()
533 .border_color(colors.border)
534 .shadow_sm()
535 .child(
536 h_flex()
537 .py_1()
538 .pl_2()
539 .pr_1()
540 .bg(colors.editor_foreground.opacity(0.05))
541 .border_b_1()
542 .border_color(colors.border)
543 .justify_between()
544 .rounded_t(px(6.))
545 .child(
546 h_flex()
547 .gap_1p5()
548 .child(
549 Icon::new(IconName::PersonCircle)
550 .size(IconSize::XSmall)
551 .color(Color::Muted),
552 )
553 .child(
554 Label::new("You")
555 .size(LabelSize::Small)
556 .color(Color::Muted),
557 ),
558 )
559 .when_some(
560 edit_message_editor.clone(),
561 |this, edit_message_editor| {
562 let focus_handle = edit_message_editor.focus_handle(cx);
563 this.child(
564 h_flex()
565 .gap_1()
566 .child(
567 Button::new("cancel-edit-message", "Cancel")
568 .label_size(LabelSize::Small)
569 .key_binding(
570 KeyBinding::for_action_in(
571 &menu::Cancel,
572 &focus_handle,
573 window,
574 cx,
575 )
576 .map(|kb| kb.size(rems_from_px(12.))),
577 )
578 .on_click(
579 cx.listener(Self::handle_cancel_click),
580 ),
581 )
582 .child(
583 Button::new(
584 "confirm-edit-message",
585 "Regenerate",
586 )
587 .label_size(LabelSize::Small)
588 .key_binding(
589 KeyBinding::for_action_in(
590 &menu::Confirm,
591 &focus_handle,
592 window,
593 cx,
594 )
595 .map(|kb| kb.size(rems_from_px(12.))),
596 )
597 .on_click(
598 cx.listener(Self::handle_regenerate_click),
599 ),
600 ),
601 )
602 },
603 )
604 .when(
605 edit_message_editor.is_none() && allow_editing_message,
606 |this| {
607 this.child(
608 Button::new("edit-message", "Edit")
609 .label_size(LabelSize::Small)
610 .on_click(cx.listener({
611 let message_text = message.text.clone();
612 move |this, _, window, cx| {
613 this.start_editing_message(
614 message_id,
615 message_text.clone(),
616 window,
617 cx,
618 );
619 }
620 })),
621 )
622 },
623 ),
624 )
625 .child(message_content),
626 ),
627 Role::Assistant => div()
628 .id(("message-container", ix))
629 .child(message_content)
630 .map(|parent| {
631 if tool_uses.is_empty() {
632 return parent;
633 }
634
635 parent.child(
636 v_flex().children(
637 tool_uses
638 .into_iter()
639 .map(|tool_use| self.render_tool_use(tool_use, cx)),
640 ),
641 )
642 }),
643 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
644 v_flex()
645 .bg(colors.editor_background)
646 .rounded_md()
647 .child(message_content),
648 ),
649 };
650
651 styled_message.into_any()
652 }
653
654 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
655 let is_open = self
656 .expanded_tool_uses
657 .get(&tool_use.id)
658 .copied()
659 .unwrap_or_default();
660
661 div().px_2p5().child(
662 v_flex()
663 .gap_1()
664 .rounded_lg()
665 .border_1()
666 .border_color(cx.theme().colors().border)
667 .child(
668 h_flex()
669 .justify_between()
670 .py_0p5()
671 .pl_1()
672 .pr_2()
673 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
674 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
675 .when(!is_open, |element| element.rounded(px(6.)))
676 .border_color(cx.theme().colors().border)
677 .child(
678 h_flex()
679 .gap_1()
680 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
681 cx.listener({
682 let tool_use_id = tool_use.id.clone();
683 move |this, _event, _window, _cx| {
684 let is_open = this
685 .expanded_tool_uses
686 .entry(tool_use_id.clone())
687 .or_insert(false);
688
689 *is_open = !*is_open;
690 }
691 }),
692 ))
693 .child(Label::new(tool_use.name)),
694 )
695 .child(
696 Label::new(match tool_use.status {
697 ToolUseStatus::Pending => "Pending",
698 ToolUseStatus::Running => "Running",
699 ToolUseStatus::Finished(_) => "Finished",
700 ToolUseStatus::Error(_) => "Error",
701 })
702 .size(LabelSize::XSmall)
703 .buffer_font(cx),
704 ),
705 )
706 .map(|parent| {
707 if !is_open {
708 return parent;
709 }
710
711 parent.child(
712 v_flex()
713 .child(
714 v_flex()
715 .gap_0p5()
716 .py_1()
717 .px_2p5()
718 .border_b_1()
719 .border_color(cx.theme().colors().border)
720 .child(Label::new("Input:"))
721 .child(Label::new(
722 serde_json::to_string_pretty(&tool_use.input)
723 .unwrap_or_default(),
724 )),
725 )
726 .map(|parent| match tool_use.status {
727 ToolUseStatus::Finished(output) => parent.child(
728 v_flex()
729 .gap_0p5()
730 .py_1()
731 .px_2p5()
732 .child(Label::new("Result:"))
733 .child(Label::new(output)),
734 ),
735 ToolUseStatus::Error(err) => parent.child(
736 v_flex()
737 .gap_0p5()
738 .py_1()
739 .px_2p5()
740 .child(Label::new("Error:"))
741 .child(Label::new(err)),
742 ),
743 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
744 }),
745 )
746 }),
747 )
748 }
749}
750
751impl Render for ActiveThread {
752 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
753 v_flex()
754 .size_full()
755 .child(list(self.list_state.clone()).flex_grow())
756 }
757}