1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use editor::{Editor, MultiBuffer};
6use gpui::{
7 list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
8 Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
9 Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
10};
11use language::{Buffer, LanguageRegistry};
12use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
13use markdown::{Markdown, MarkdownStyle};
14use settings::Settings as _;
15use theme::ThemeSettings;
16use ui::{prelude::*, Disclosure, KeyBinding};
17use util::ResultExt as _;
18use workspace::Workspace;
19
20use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
21use crate::thread_store::ThreadStore;
22use crate::tool_use::{ToolUse, ToolUseStatus};
23use crate::ui::ContextPill;
24
/// View over a single assistant [`Thread`].
///
/// Renders the thread's messages in a virtualized list, reacts to [`ThreadEvent`]s, runs idle
/// pending tool uses, persists the thread through the [`ThreadStore`], and supports inline
/// editing of the most recent user message.
pub struct ActiveThread {
    /// Handed to tools when they are run.
    workspace: WeakEntity<Workspace>,
    /// Passed to every [`Markdown`] entity created for a message.
    language_registry: Arc<LanguageRegistry>,
    /// Tool set used to look up pending tool uses by name and run them.
    tools: Arc<ToolWorkingSet>,
    /// Store the thread is saved to after changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Most recently spawned save task; replaced each time the thread is saved.
    save_thread_task: Option<Task<()>>,
    /// Message ids in display order, kept index-aligned with the items of `list_state`.
    messages: Vec<MessageId>,
    /// State of the virtualized message list; items are rendered by `render_message`.
    list_state: ListState,
    /// Markdown entity holding each message's rendered text, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited inline, together with its editor.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Open/closed state of each tool use's disclosure; a missing entry means closed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error reported via `ThreadEvent::ShowError`, until cleared.
    last_error: Option<ThreadError>,
    /// Subscriptions to the thread; held so they stay active for the view's lifetime.
    _subscriptions: Vec<Subscription>,
}
40
/// State for a user message that is being edited inline.
struct EditMessageState {
    /// Editor holding the in-progress text; its contents are read when the edit is confirmed.
    editor: Entity<Editor>,
}
44
45impl ActiveThread {
46 pub fn new(
47 thread: Entity<Thread>,
48 thread_store: Entity<ThreadStore>,
49 workspace: WeakEntity<Workspace>,
50 language_registry: Arc<LanguageRegistry>,
51 tools: Arc<ToolWorkingSet>,
52 window: &mut Window,
53 cx: &mut Context<Self>,
54 ) -> Self {
55 let subscriptions = vec![
56 cx.observe(&thread, |_, _, cx| cx.notify()),
57 cx.subscribe_in(&thread, window, Self::handle_thread_event),
58 ];
59
60 let mut this = Self {
61 workspace,
62 language_registry,
63 tools,
64 thread_store,
65 thread: thread.clone(),
66 save_thread_task: None,
67 messages: Vec::new(),
68 rendered_messages_by_id: HashMap::default(),
69 expanded_tool_uses: HashMap::default(),
70 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
71 let this = cx.entity().downgrade();
72 move |ix, window: &mut Window, cx: &mut App| {
73 this.update(cx, |this, cx| this.render_message(ix, window, cx))
74 .unwrap()
75 }
76 }),
77 editing_message: None,
78 last_error: None,
79 _subscriptions: subscriptions,
80 };
81
82 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
83 this.push_message(&message.id, message.text.clone(), window, cx);
84 }
85
86 this
87 }
88
89 pub fn thread(&self) -> &Entity<Thread> {
90 &self.thread
91 }
92
93 pub fn is_empty(&self) -> bool {
94 self.messages.is_empty()
95 }
96
97 pub fn summary(&self, cx: &App) -> Option<SharedString> {
98 self.thread.read(cx).summary()
99 }
100
101 pub fn summary_or_default(&self, cx: &App) -> SharedString {
102 self.thread.read(cx).summary_or_default()
103 }
104
105 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
106 self.last_error.take();
107 self.thread
108 .update(cx, |thread, _cx| thread.cancel_last_completion())
109 }
110
111 pub fn last_error(&self) -> Option<ThreadError> {
112 self.last_error.clone()
113 }
114
115 pub fn clear_last_error(&mut self) {
116 self.last_error.take();
117 }
118
119 fn push_message(
120 &mut self,
121 id: &MessageId,
122 text: String,
123 window: &mut Window,
124 cx: &mut Context<Self>,
125 ) {
126 let old_len = self.messages.len();
127 self.messages.push(*id);
128 self.list_state.splice(old_len..old_len, 1);
129
130 let markdown = self.render_markdown(text.into(), window, cx);
131 self.rendered_messages_by_id.insert(*id, markdown);
132 self.list_state.scroll_to(ListOffset {
133 item_ix: old_len,
134 offset_in_item: Pixels(0.0),
135 });
136 }
137
138 fn edited_message(
139 &mut self,
140 id: &MessageId,
141 text: String,
142 window: &mut Window,
143 cx: &mut Context<Self>,
144 ) {
145 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
146 return;
147 };
148 self.list_state.splice(index..index + 1, 1);
149 let markdown = self.render_markdown(text.into(), window, cx);
150 self.rendered_messages_by_id.insert(*id, markdown);
151 }
152
153 fn deleted_message(&mut self, id: &MessageId) {
154 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
155 return;
156 };
157 self.messages.remove(index);
158 self.list_state.splice(index..index + 1, 0);
159 self.rendered_messages_by_id.remove(id);
160 }
161
162 fn render_markdown(
163 &self,
164 text: SharedString,
165 window: &Window,
166 cx: &mut Context<Self>,
167 ) -> Entity<Markdown> {
168 let theme_settings = ThemeSettings::get_global(cx);
169 let colors = cx.theme().colors();
170 let ui_font_size = TextSize::Default.rems(cx);
171 let buffer_font_size = TextSize::Small.rems(cx);
172 let mut text_style = window.text_style();
173
174 text_style.refine(&TextStyleRefinement {
175 font_family: Some(theme_settings.ui_font.family.clone()),
176 font_size: Some(ui_font_size.into()),
177 color: Some(cx.theme().colors().text),
178 ..Default::default()
179 });
180
181 let markdown_style = MarkdownStyle {
182 base_text_style: text_style,
183 syntax: cx.theme().syntax().clone(),
184 selection_background_color: cx.theme().players().local().selection,
185 code_block_overflow_x_scroll: true,
186 table_overflow_x_scroll: true,
187 code_block: StyleRefinement {
188 margin: EdgesRefinement {
189 top: Some(Length::Definite(rems(0.).into())),
190 left: Some(Length::Definite(rems(0.).into())),
191 right: Some(Length::Definite(rems(0.).into())),
192 bottom: Some(Length::Definite(rems(0.5).into())),
193 },
194 padding: EdgesRefinement {
195 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
196 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
197 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
198 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
199 },
200 background: Some(colors.editor_background.into()),
201 border_color: Some(colors.border_variant),
202 border_widths: EdgesRefinement {
203 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
204 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
205 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
206 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
207 },
208 text: Some(TextStyleRefinement {
209 font_family: Some(theme_settings.buffer_font.family.clone()),
210 font_size: Some(buffer_font_size.into()),
211 ..Default::default()
212 }),
213 ..Default::default()
214 },
215 inline_code: TextStyleRefinement {
216 font_family: Some(theme_settings.buffer_font.family.clone()),
217 font_size: Some(buffer_font_size.into()),
218 background_color: Some(colors.editor_foreground.opacity(0.1)),
219 ..Default::default()
220 },
221 link: TextStyleRefinement {
222 background_color: Some(colors.editor_foreground.opacity(0.025)),
223 underline: Some(UnderlineStyle {
224 color: Some(colors.text_accent.opacity(0.5)),
225 thickness: px(1.),
226 ..Default::default()
227 }),
228 ..Default::default()
229 },
230 ..Default::default()
231 };
232
233 cx.new(|cx| {
234 Markdown::new(
235 text,
236 markdown_style,
237 Some(self.language_registry.clone()),
238 None,
239 cx,
240 )
241 })
242 }
243
244 fn handle_thread_event(
245 &mut self,
246 _: &Entity<Thread>,
247 event: &ThreadEvent,
248 window: &mut Window,
249 cx: &mut Context<Self>,
250 ) {
251 match event {
252 ThreadEvent::ShowError(error) => {
253 self.last_error = Some(error.clone());
254 }
255 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
256 self.save_thread(cx);
257 }
258 ThreadEvent::StreamedAssistantText(message_id, text) => {
259 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
260 markdown.update(cx, |markdown, cx| {
261 markdown.append(text, cx);
262 });
263 }
264 }
265 ThreadEvent::MessageAdded(message_id) => {
266 if let Some(message_text) = self
267 .thread
268 .read(cx)
269 .message(*message_id)
270 .map(|message| message.text.clone())
271 {
272 self.push_message(message_id, message_text, window, cx);
273 }
274
275 self.save_thread(cx);
276 cx.notify();
277 }
278 ThreadEvent::MessageEdited(message_id) => {
279 if let Some(message_text) = self
280 .thread
281 .read(cx)
282 .message(*message_id)
283 .map(|message| message.text.clone())
284 {
285 self.edited_message(message_id, message_text, window, cx);
286 }
287
288 self.save_thread(cx);
289 cx.notify();
290 }
291 ThreadEvent::MessageDeleted(message_id) => {
292 self.deleted_message(message_id);
293 self.save_thread(cx);
294 cx.notify();
295 }
296 ThreadEvent::UsePendingTools => {
297 let pending_tool_uses = self
298 .thread
299 .read(cx)
300 .pending_tool_uses()
301 .into_iter()
302 .filter(|tool_use| tool_use.status.is_idle())
303 .cloned()
304 .collect::<Vec<_>>();
305
306 for tool_use in pending_tool_uses {
307 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
308 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
309
310 self.thread.update(cx, |thread, cx| {
311 thread.insert_tool_output(tool_use.id.clone(), task, cx);
312 });
313 }
314 }
315 }
316 ThreadEvent::ToolFinished { .. } => {
317 let all_tools_finished = self
318 .thread
319 .read(cx)
320 .pending_tool_uses()
321 .into_iter()
322 .all(|tool_use| tool_use.status.is_error());
323 if all_tools_finished {
324 let model_registry = LanguageModelRegistry::read_global(cx);
325 if let Some(model) = model_registry.active_model() {
326 self.thread.update(cx, |thread, cx| {
327 // Insert a user message to contain the tool results.
328 thread.insert_user_message(
329 // TODO: Sending up a user message without any content results in the model sending back
330 // responses that also don't have any content. We currently don't handle this case well,
331 // so for now we provide some text to keep the model on track.
332 "Here are the tool results.",
333 Vec::new(),
334 cx,
335 );
336 thread.send_to_model(model, RequestKind::Chat, true, cx);
337 });
338 }
339 }
340 }
341 }
342 }
343
344 /// Spawns a task to save the active thread.
345 ///
346 /// Only one task to save the thread will be in flight at a time.
347 fn save_thread(&mut self, cx: &mut Context<Self>) {
348 let thread = self.thread.clone();
349 self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
350 let task = this
351 .update(&mut cx, |this, cx| {
352 this.thread_store
353 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
354 })
355 .ok();
356
357 if let Some(task) = task {
358 task.await.log_err();
359 }
360 }));
361 }
362
363 fn start_editing_message(
364 &mut self,
365 message_id: MessageId,
366 message_text: String,
367 window: &mut Window,
368 cx: &mut Context<Self>,
369 ) {
370 let buffer = cx.new(|cx| {
371 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
372 });
373 let editor = cx.new(|cx| {
374 let mut editor = Editor::new(
375 editor::EditorMode::AutoHeight { max_lines: 8 },
376 buffer,
377 None,
378 false,
379 window,
380 cx,
381 );
382 editor.focus_handle(cx).focus(window);
383 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
384 editor
385 });
386 self.editing_message = Some((
387 message_id,
388 EditMessageState {
389 editor: editor.clone(),
390 },
391 ));
392 cx.notify();
393 }
394
395 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
396 self.editing_message.take();
397 cx.notify();
398 }
399
400 fn confirm_editing_message(
401 &mut self,
402 _: &menu::Confirm,
403 _: &mut Window,
404 cx: &mut Context<Self>,
405 ) {
406 let Some((message_id, state)) = self.editing_message.take() else {
407 return;
408 };
409 let edited_text = state.editor.read(cx).text(cx);
410 self.thread.update(cx, |thread, cx| {
411 thread.edit_message(message_id, Role::User, edited_text, cx);
412 for message_id in self.messages_after(message_id) {
413 thread.delete_message(*message_id, cx);
414 }
415 });
416
417 let provider = LanguageModelRegistry::read_global(cx).active_provider();
418 if provider
419 .as_ref()
420 .map_or(false, |provider| provider.must_accept_terms(cx))
421 {
422 cx.notify();
423 return;
424 }
425 let model_registry = LanguageModelRegistry::read_global(cx);
426 let Some(model) = model_registry.active_model() else {
427 return;
428 };
429
430 self.thread.update(cx, |thread, cx| {
431 thread.send_to_model(model, RequestKind::Chat, false, cx)
432 });
433 cx.notify();
434 }
435
436 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
437 self.messages
438 .iter()
439 .rev()
440 .find(|message_id| {
441 self.thread
442 .read(cx)
443 .message(**message_id)
444 .map_or(false, |message| message.role == Role::User)
445 })
446 .cloned()
447 }
448
449 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
450 self.messages
451 .iter()
452 .position(|id| *id == message_id)
453 .map(|index| &self.messages[index + 1..])
454 .unwrap_or(&[])
455 }
456
457 fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
458 self.cancel_editing_message(&menu::Cancel, window, cx);
459 }
460
461 fn handle_regenerate_click(
462 &mut self,
463 _: &ClickEvent,
464 window: &mut Window,
465 cx: &mut Context<Self>,
466 ) {
467 self.confirm_editing_message(&menu::Confirm, window, cx);
468 }
469
470 fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
471 let message_id = self.messages[ix];
472 let Some(message) = self.thread.read(cx).message(message_id) else {
473 return Empty.into_any();
474 };
475
476 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
477 return Empty.into_any();
478 };
479
480 let context = self.thread.read(cx).context_for_message(message_id);
481 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
482 let colors = cx.theme().colors();
483
484 // Don't render user messages that are just there for returning tool results.
485 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
486 return Empty.into_any();
487 }
488
489 let allow_editing_message =
490 message.role == Role::User && self.last_user_message(cx) == Some(message_id);
491
492 let edit_message_editor = self
493 .editing_message
494 .as_ref()
495 .filter(|(id, _)| *id == message_id)
496 .map(|(_, state)| state.editor.clone());
497
498 let message_content = v_flex()
499 .child(
500 if let Some(edit_message_editor) = edit_message_editor.clone() {
501 div()
502 .key_context("EditMessageEditor")
503 .on_action(cx.listener(Self::cancel_editing_message))
504 .on_action(cx.listener(Self::confirm_editing_message))
505 .p_2p5()
506 .child(edit_message_editor)
507 } else {
508 div().p_2p5().text_ui(cx).child(markdown.clone())
509 },
510 )
511 .when_some(context, |parent, context| {
512 if !context.is_empty() {
513 parent.child(
514 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
515 context
516 .into_iter()
517 .map(|context| ContextPill::added(context, false, false, None)),
518 ),
519 )
520 } else {
521 parent
522 }
523 });
524
525 let styled_message = match message.role {
526 Role::User => v_flex()
527 .id(("message-container", ix))
528 .pt_2p5()
529 .px_2p5()
530 .child(
531 v_flex()
532 .bg(colors.editor_background)
533 .rounded_lg()
534 .border_1()
535 .border_color(colors.border)
536 .shadow_sm()
537 .child(
538 h_flex()
539 .py_1()
540 .pl_2()
541 .pr_1()
542 .bg(colors.editor_foreground.opacity(0.05))
543 .border_b_1()
544 .border_color(colors.border)
545 .justify_between()
546 .rounded_t(px(6.))
547 .child(
548 h_flex()
549 .gap_1p5()
550 .child(
551 Icon::new(IconName::PersonCircle)
552 .size(IconSize::XSmall)
553 .color(Color::Muted),
554 )
555 .child(
556 Label::new("You")
557 .size(LabelSize::Small)
558 .color(Color::Muted),
559 ),
560 )
561 .when_some(
562 edit_message_editor.clone(),
563 |this, edit_message_editor| {
564 let focus_handle = edit_message_editor.focus_handle(cx);
565 this.child(
566 h_flex()
567 .gap_1()
568 .child(
569 Button::new("cancel-edit-message", "Cancel")
570 .label_size(LabelSize::Small)
571 .key_binding(
572 KeyBinding::for_action_in(
573 &menu::Cancel,
574 &focus_handle,
575 window,
576 cx,
577 )
578 .map(|kb| kb.size(rems_from_px(12.))),
579 )
580 .on_click(
581 cx.listener(Self::handle_cancel_click),
582 ),
583 )
584 .child(
585 Button::new(
586 "confirm-edit-message",
587 "Regenerate",
588 )
589 .label_size(LabelSize::Small)
590 .key_binding(
591 KeyBinding::for_action_in(
592 &menu::Confirm,
593 &focus_handle,
594 window,
595 cx,
596 )
597 .map(|kb| kb.size(rems_from_px(12.))),
598 )
599 .on_click(
600 cx.listener(Self::handle_regenerate_click),
601 ),
602 ),
603 )
604 },
605 )
606 .when(
607 edit_message_editor.is_none() && allow_editing_message,
608 |this| {
609 this.child(
610 Button::new("edit-message", "Edit")
611 .label_size(LabelSize::Small)
612 .on_click(cx.listener({
613 let message_text = message.text.clone();
614 move |this, _, window, cx| {
615 this.start_editing_message(
616 message_id,
617 message_text.clone(),
618 window,
619 cx,
620 );
621 }
622 })),
623 )
624 },
625 ),
626 )
627 .child(message_content),
628 ),
629 Role::Assistant => div()
630 .id(("message-container", ix))
631 .child(message_content)
632 .map(|parent| {
633 if tool_uses.is_empty() {
634 return parent;
635 }
636
637 parent.child(
638 v_flex().children(
639 tool_uses
640 .into_iter()
641 .map(|tool_use| self.render_tool_use(tool_use, cx)),
642 ),
643 )
644 }),
645 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
646 v_flex()
647 .bg(colors.editor_background)
648 .rounded_sm()
649 .child(message_content),
650 ),
651 };
652
653 styled_message.into_any()
654 }
655
656 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
657 let is_open = self
658 .expanded_tool_uses
659 .get(&tool_use.id)
660 .copied()
661 .unwrap_or_default();
662
663 div().px_2p5().child(
664 v_flex()
665 .gap_1()
666 .rounded_lg()
667 .border_1()
668 .border_color(cx.theme().colors().border)
669 .child(
670 h_flex()
671 .justify_between()
672 .py_0p5()
673 .pl_1()
674 .pr_2()
675 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
676 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
677 .when(!is_open, |element| element.rounded_md())
678 .border_color(cx.theme().colors().border)
679 .child(
680 h_flex()
681 .gap_1()
682 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
683 cx.listener({
684 let tool_use_id = tool_use.id.clone();
685 move |this, _event, _window, _cx| {
686 let is_open = this
687 .expanded_tool_uses
688 .entry(tool_use_id.clone())
689 .or_insert(false);
690
691 *is_open = !*is_open;
692 }
693 }),
694 ))
695 .child(Label::new(tool_use.name)),
696 )
697 .child(
698 Label::new(match tool_use.status {
699 ToolUseStatus::Pending => "Pending",
700 ToolUseStatus::Running => "Running",
701 ToolUseStatus::Finished(_) => "Finished",
702 ToolUseStatus::Error(_) => "Error",
703 })
704 .size(LabelSize::XSmall)
705 .buffer_font(cx),
706 ),
707 )
708 .map(|parent| {
709 if !is_open {
710 return parent;
711 }
712
713 parent.child(
714 v_flex()
715 .child(
716 v_flex()
717 .gap_0p5()
718 .py_1()
719 .px_2p5()
720 .border_b_1()
721 .border_color(cx.theme().colors().border)
722 .child(Label::new("Input:"))
723 .child(Label::new(
724 serde_json::to_string_pretty(&tool_use.input)
725 .unwrap_or_default(),
726 )),
727 )
728 .map(|parent| match tool_use.status {
729 ToolUseStatus::Finished(output) => parent.child(
730 v_flex()
731 .gap_0p5()
732 .py_1()
733 .px_2p5()
734 .child(Label::new("Result:"))
735 .child(Label::new(output)),
736 ),
737 ToolUseStatus::Error(err) => parent.child(
738 v_flex()
739 .gap_0p5()
740 .py_1()
741 .px_2p5()
742 .child(Label::new("Error:"))
743 .child(Label::new(err)),
744 ),
745 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
746 }),
747 )
748 }),
749 )
750 }
751}
752
753impl Render for ActiveThread {
754 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
755 v_flex()
756 .size_full()
757 .child(list(self.list_state.clone()).flex_grow())
758 }
759}