1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use editor::{Editor, MultiBuffer};
6use gpui::{
7 list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
8 Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
9 Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
10};
11use language::{Buffer, LanguageRegistry};
12use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
13use markdown::{Markdown, MarkdownStyle};
14use project::Project;
15use settings::Settings as _;
16use theme::ThemeSettings;
17use ui::{prelude::*, Disclosure, KeyBinding};
18use util::ResultExt as _;
19
20use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
21use crate::thread_store::ThreadStore;
22use crate::tool_use::{ToolUse, ToolUseStatus};
23use crate::ui::ContextPill;
24
/// View that renders a single [`Thread`] as a scrollable list of messages,
/// runs the thread's pending tool uses, and supports editing the latest user
/// message.
pub struct ActiveThread {
    /// Project handed to each tool when it is run.
    project: WeakEntity<Project>,
    /// Passed to every message's [`Markdown`] renderer.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks to use pending tools.
    tools: Arc<ToolWorkingSet>,
    /// Store the thread is saved into.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Most recently spawned save task; assigning a new one drops the previous value.
    save_thread_task: Option<Task<()>>,
    /// Message ids in display order; entry `i` backs list item `i` in `list_state`.
    messages: Vec<MessageId>,
    /// Virtualized list state whose item count is kept in sync with `messages`.
    list_state: ListState,
    /// Rendered markdown for each message shown in the list.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited together with its editor, if any.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Disclosure state per tool use; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Last error reported by the thread, until cleared.
    last_error: Option<ThreadError>,
    /// Subscriptions to `thread` created in `new`; held for the lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
40
/// State kept while a message is being edited in place.
struct EditMessageState {
    /// Editor pre-filled with the message text; its contents replace the
    /// message when the edit is confirmed.
    editor: Entity<Editor>,
}
44
45impl ActiveThread {
46 pub fn new(
47 thread: Entity<Thread>,
48 thread_store: Entity<ThreadStore>,
49 project: WeakEntity<Project>,
50 language_registry: Arc<LanguageRegistry>,
51 tools: Arc<ToolWorkingSet>,
52 window: &mut Window,
53 cx: &mut Context<Self>,
54 ) -> Self {
55 let subscriptions = vec![
56 cx.observe(&thread, |_, _, cx| cx.notify()),
57 cx.subscribe_in(&thread, window, Self::handle_thread_event),
58 ];
59
60 let mut this = Self {
61 project,
62 language_registry,
63 tools,
64 thread_store,
65 thread: thread.clone(),
66 save_thread_task: None,
67 messages: Vec::new(),
68 rendered_messages_by_id: HashMap::default(),
69 expanded_tool_uses: HashMap::default(),
70 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
71 let this = cx.entity().downgrade();
72 move |ix, window: &mut Window, cx: &mut App| {
73 this.update(cx, |this, cx| this.render_message(ix, window, cx))
74 .unwrap()
75 }
76 }),
77 editing_message: None,
78 last_error: None,
79 _subscriptions: subscriptions,
80 };
81
82 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
83 this.push_message(&message.id, message.text.clone(), window, cx);
84 }
85
86 this
87 }
88
89 pub fn thread(&self) -> &Entity<Thread> {
90 &self.thread
91 }
92
93 pub fn is_empty(&self) -> bool {
94 self.messages.is_empty()
95 }
96
97 pub fn summary(&self, cx: &App) -> Option<SharedString> {
98 self.thread.read(cx).summary()
99 }
100
101 pub fn summary_or_default(&self, cx: &App) -> SharedString {
102 self.thread.read(cx).summary_or_default()
103 }
104
105 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
106 self.last_error.take();
107 self.thread
108 .update(cx, |thread, _cx| thread.cancel_last_completion())
109 }
110
111 pub fn last_error(&self) -> Option<ThreadError> {
112 self.last_error.clone()
113 }
114
115 pub fn clear_last_error(&mut self) {
116 self.last_error.take();
117 }
118
119 fn push_message(
120 &mut self,
121 id: &MessageId,
122 text: String,
123 window: &mut Window,
124 cx: &mut Context<Self>,
125 ) {
126 let old_len = self.messages.len();
127 self.messages.push(*id);
128 self.list_state.splice(old_len..old_len, 1);
129
130 let markdown = self.render_markdown(text.into(), window, cx);
131 self.rendered_messages_by_id.insert(*id, markdown);
132 self.list_state.scroll_to(ListOffset {
133 item_ix: old_len,
134 offset_in_item: Pixels(0.0),
135 });
136 }
137
138 fn edited_message(
139 &mut self,
140 id: &MessageId,
141 text: String,
142 window: &mut Window,
143 cx: &mut Context<Self>,
144 ) {
145 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
146 return;
147 };
148 self.list_state.splice(index..index + 1, 1);
149 let markdown = self.render_markdown(text.into(), window, cx);
150 self.rendered_messages_by_id.insert(*id, markdown);
151 }
152
153 fn deleted_message(&mut self, id: &MessageId) {
154 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
155 return;
156 };
157 self.messages.remove(index);
158 self.list_state.splice(index..index + 1, 0);
159 self.rendered_messages_by_id.remove(id);
160 }
161
162 fn render_markdown(
163 &self,
164 text: SharedString,
165 window: &Window,
166 cx: &mut Context<Self>,
167 ) -> Entity<Markdown> {
168 let theme_settings = ThemeSettings::get_global(cx);
169 let colors = cx.theme().colors();
170 let ui_font_size = TextSize::Default.rems(cx);
171 let buffer_font_size = TextSize::Small.rems(cx);
172 let mut text_style = window.text_style();
173
174 text_style.refine(&TextStyleRefinement {
175 font_family: Some(theme_settings.ui_font.family.clone()),
176 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
177 font_features: Some(theme_settings.ui_font.features.clone()),
178 font_size: Some(ui_font_size.into()),
179 color: Some(cx.theme().colors().text),
180 ..Default::default()
181 });
182
183 let markdown_style = MarkdownStyle {
184 base_text_style: text_style,
185 syntax: cx.theme().syntax().clone(),
186 selection_background_color: cx.theme().players().local().selection,
187 code_block_overflow_x_scroll: true,
188 table_overflow_x_scroll: true,
189 code_block: StyleRefinement {
190 margin: EdgesRefinement {
191 top: Some(Length::Definite(rems(0.).into())),
192 left: Some(Length::Definite(rems(0.).into())),
193 right: Some(Length::Definite(rems(0.).into())),
194 bottom: Some(Length::Definite(rems(0.5).into())),
195 },
196 padding: EdgesRefinement {
197 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
198 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
199 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
200 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
201 },
202 background: Some(colors.editor_background.into()),
203 border_color: Some(colors.border_variant),
204 border_widths: EdgesRefinement {
205 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
206 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
207 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
208 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
209 },
210 text: Some(TextStyleRefinement {
211 font_family: Some(theme_settings.buffer_font.family.clone()),
212 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
213 font_features: Some(theme_settings.buffer_font.features.clone()),
214 font_size: Some(buffer_font_size.into()),
215 ..Default::default()
216 }),
217 ..Default::default()
218 },
219 inline_code: TextStyleRefinement {
220 font_family: Some(theme_settings.buffer_font.family.clone()),
221 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
222 font_features: Some(theme_settings.buffer_font.features.clone()),
223 font_size: Some(buffer_font_size.into()),
224 background_color: Some(colors.editor_foreground.opacity(0.1)),
225 ..Default::default()
226 },
227 link: TextStyleRefinement {
228 background_color: Some(colors.editor_foreground.opacity(0.025)),
229 underline: Some(UnderlineStyle {
230 color: Some(colors.text_accent.opacity(0.5)),
231 thickness: px(1.),
232 ..Default::default()
233 }),
234 ..Default::default()
235 },
236 ..Default::default()
237 };
238
239 cx.new(|cx| {
240 Markdown::new(
241 text,
242 markdown_style,
243 Some(self.language_registry.clone()),
244 None,
245 cx,
246 )
247 })
248 }
249
250 fn handle_thread_event(
251 &mut self,
252 _: &Entity<Thread>,
253 event: &ThreadEvent,
254 window: &mut Window,
255 cx: &mut Context<Self>,
256 ) {
257 match event {
258 ThreadEvent::ShowError(error) => {
259 self.last_error = Some(error.clone());
260 }
261 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
262 self.save_thread(cx);
263 }
264 ThreadEvent::StreamedAssistantText(message_id, text) => {
265 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
266 markdown.update(cx, |markdown, cx| {
267 markdown.append(text, cx);
268 });
269 }
270 }
271 ThreadEvent::MessageAdded(message_id) => {
272 if let Some(message_text) = self
273 .thread
274 .read(cx)
275 .message(*message_id)
276 .map(|message| message.text.clone())
277 {
278 self.push_message(message_id, message_text, window, cx);
279 }
280
281 self.save_thread(cx);
282 cx.notify();
283 }
284 ThreadEvent::MessageEdited(message_id) => {
285 if let Some(message_text) = self
286 .thread
287 .read(cx)
288 .message(*message_id)
289 .map(|message| message.text.clone())
290 {
291 self.edited_message(message_id, message_text, window, cx);
292 }
293
294 self.save_thread(cx);
295 cx.notify();
296 }
297 ThreadEvent::MessageDeleted(message_id) => {
298 self.deleted_message(message_id);
299 self.save_thread(cx);
300 cx.notify();
301 }
302 ThreadEvent::UsePendingTools => {
303 let pending_tool_uses = self
304 .thread
305 .read(cx)
306 .pending_tool_uses()
307 .into_iter()
308 .filter(|tool_use| tool_use.status.is_idle())
309 .cloned()
310 .collect::<Vec<_>>();
311
312 for tool_use in pending_tool_uses {
313 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
314 let task = tool.run(tool_use.input, self.project.clone(), cx);
315
316 self.thread.update(cx, |thread, cx| {
317 thread.insert_tool_output(tool_use.id.clone(), task, cx);
318 });
319 }
320 }
321 }
322 ThreadEvent::ToolFinished { .. } => {
323 let all_tools_finished = self
324 .thread
325 .read(cx)
326 .pending_tool_uses()
327 .into_iter()
328 .all(|tool_use| tool_use.status.is_error());
329 if all_tools_finished {
330 let model_registry = LanguageModelRegistry::read_global(cx);
331 if let Some(model) = model_registry.active_model() {
332 self.thread.update(cx, |thread, cx| {
333 // Insert a user message to contain the tool results.
334 thread.insert_user_message(
335 // TODO: Sending up a user message without any content results in the model sending back
336 // responses that also don't have any content. We currently don't handle this case well,
337 // so for now we provide some text to keep the model on track.
338 "Here are the tool results.",
339 Vec::new(),
340 cx,
341 );
342 thread.send_to_model(model, RequestKind::Chat, true, cx);
343 });
344 }
345 }
346 }
347 }
348 }
349
350 /// Spawns a task to save the active thread.
351 ///
352 /// Only one task to save the thread will be in flight at a time.
353 fn save_thread(&mut self, cx: &mut Context<Self>) {
354 let thread = self.thread.clone();
355 self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
356 let task = this
357 .update(&mut cx, |this, cx| {
358 this.thread_store
359 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
360 })
361 .ok();
362
363 if let Some(task) = task {
364 task.await.log_err();
365 }
366 }));
367 }
368
369 fn start_editing_message(
370 &mut self,
371 message_id: MessageId,
372 message_text: String,
373 window: &mut Window,
374 cx: &mut Context<Self>,
375 ) {
376 let buffer = cx.new(|cx| {
377 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
378 });
379 let editor = cx.new(|cx| {
380 let mut editor = Editor::new(
381 editor::EditorMode::AutoHeight { max_lines: 8 },
382 buffer,
383 None,
384 false,
385 window,
386 cx,
387 );
388 editor.focus_handle(cx).focus(window);
389 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
390 editor
391 });
392 self.editing_message = Some((
393 message_id,
394 EditMessageState {
395 editor: editor.clone(),
396 },
397 ));
398 cx.notify();
399 }
400
401 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
402 self.editing_message.take();
403 cx.notify();
404 }
405
406 fn confirm_editing_message(
407 &mut self,
408 _: &menu::Confirm,
409 _: &mut Window,
410 cx: &mut Context<Self>,
411 ) {
412 let Some((message_id, state)) = self.editing_message.take() else {
413 return;
414 };
415 let edited_text = state.editor.read(cx).text(cx);
416 self.thread.update(cx, |thread, cx| {
417 thread.edit_message(message_id, Role::User, edited_text, cx);
418 for message_id in self.messages_after(message_id) {
419 thread.delete_message(*message_id, cx);
420 }
421 });
422
423 let provider = LanguageModelRegistry::read_global(cx).active_provider();
424 if provider
425 .as_ref()
426 .map_or(false, |provider| provider.must_accept_terms(cx))
427 {
428 cx.notify();
429 return;
430 }
431 let model_registry = LanguageModelRegistry::read_global(cx);
432 let Some(model) = model_registry.active_model() else {
433 return;
434 };
435
436 self.thread.update(cx, |thread, cx| {
437 thread.send_to_model(model, RequestKind::Chat, false, cx)
438 });
439 cx.notify();
440 }
441
442 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
443 self.messages
444 .iter()
445 .rev()
446 .find(|message_id| {
447 self.thread
448 .read(cx)
449 .message(**message_id)
450 .map_or(false, |message| message.role == Role::User)
451 })
452 .cloned()
453 }
454
455 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
456 self.messages
457 .iter()
458 .position(|id| *id == message_id)
459 .map(|index| &self.messages[index + 1..])
460 .unwrap_or(&[])
461 }
462
463 fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
464 self.cancel_editing_message(&menu::Cancel, window, cx);
465 }
466
467 fn handle_regenerate_click(
468 &mut self,
469 _: &ClickEvent,
470 window: &mut Window,
471 cx: &mut Context<Self>,
472 ) {
473 self.confirm_editing_message(&menu::Confirm, window, cx);
474 }
475
476 fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
477 let message_id = self.messages[ix];
478 let Some(message) = self.thread.read(cx).message(message_id) else {
479 return Empty.into_any();
480 };
481
482 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
483 return Empty.into_any();
484 };
485
486 let context = self.thread.read(cx).context_for_message(message_id);
487 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
488 let colors = cx.theme().colors();
489
490 // Don't render user messages that are just there for returning tool results.
491 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
492 return Empty.into_any();
493 }
494
495 let allow_editing_message =
496 message.role == Role::User && self.last_user_message(cx) == Some(message_id);
497
498 let edit_message_editor = self
499 .editing_message
500 .as_ref()
501 .filter(|(id, _)| *id == message_id)
502 .map(|(_, state)| state.editor.clone());
503
504 let message_content = v_flex()
505 .child(
506 if let Some(edit_message_editor) = edit_message_editor.clone() {
507 div()
508 .key_context("EditMessageEditor")
509 .on_action(cx.listener(Self::cancel_editing_message))
510 .on_action(cx.listener(Self::confirm_editing_message))
511 .p_2p5()
512 .child(edit_message_editor)
513 } else {
514 div().p_2p5().text_ui(cx).child(markdown.clone())
515 },
516 )
517 .when_some(context, |parent, context| {
518 if !context.is_empty() {
519 parent.child(
520 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
521 context
522 .into_iter()
523 .map(|context| ContextPill::added(context, false, false, None)),
524 ),
525 )
526 } else {
527 parent
528 }
529 });
530
531 let styled_message = match message.role {
532 Role::User => v_flex()
533 .id(("message-container", ix))
534 .pt_2p5()
535 .px_2p5()
536 .child(
537 v_flex()
538 .bg(colors.editor_background)
539 .rounded_lg()
540 .border_1()
541 .border_color(colors.border)
542 .shadow_sm()
543 .child(
544 h_flex()
545 .py_1()
546 .pl_2()
547 .pr_1()
548 .bg(colors.editor_foreground.opacity(0.05))
549 .border_b_1()
550 .border_color(colors.border)
551 .justify_between()
552 .rounded_t(px(6.))
553 .child(
554 h_flex()
555 .gap_1p5()
556 .child(
557 Icon::new(IconName::PersonCircle)
558 .size(IconSize::XSmall)
559 .color(Color::Muted),
560 )
561 .child(
562 Label::new("You")
563 .size(LabelSize::Small)
564 .color(Color::Muted),
565 ),
566 )
567 .when_some(
568 edit_message_editor.clone(),
569 |this, edit_message_editor| {
570 let focus_handle = edit_message_editor.focus_handle(cx);
571 this.child(
572 h_flex()
573 .gap_1()
574 .child(
575 Button::new("cancel-edit-message", "Cancel")
576 .label_size(LabelSize::Small)
577 .key_binding(
578 KeyBinding::for_action_in(
579 &menu::Cancel,
580 &focus_handle,
581 window,
582 cx,
583 )
584 .map(|kb| kb.size(rems_from_px(12.))),
585 )
586 .on_click(
587 cx.listener(Self::handle_cancel_click),
588 ),
589 )
590 .child(
591 Button::new(
592 "confirm-edit-message",
593 "Regenerate",
594 )
595 .label_size(LabelSize::Small)
596 .key_binding(
597 KeyBinding::for_action_in(
598 &menu::Confirm,
599 &focus_handle,
600 window,
601 cx,
602 )
603 .map(|kb| kb.size(rems_from_px(12.))),
604 )
605 .on_click(
606 cx.listener(Self::handle_regenerate_click),
607 ),
608 ),
609 )
610 },
611 )
612 .when(
613 edit_message_editor.is_none() && allow_editing_message,
614 |this| {
615 this.child(
616 Button::new("edit-message", "Edit")
617 .label_size(LabelSize::Small)
618 .on_click(cx.listener({
619 let message_text = message.text.clone();
620 move |this, _, window, cx| {
621 this.start_editing_message(
622 message_id,
623 message_text.clone(),
624 window,
625 cx,
626 );
627 }
628 })),
629 )
630 },
631 ),
632 )
633 .child(message_content),
634 ),
635 Role::Assistant => div()
636 .id(("message-container", ix))
637 .child(message_content)
638 .map(|parent| {
639 if tool_uses.is_empty() {
640 return parent;
641 }
642
643 parent.child(
644 v_flex().children(
645 tool_uses
646 .into_iter()
647 .map(|tool_use| self.render_tool_use(tool_use, cx)),
648 ),
649 )
650 }),
651 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
652 v_flex()
653 .bg(colors.editor_background)
654 .rounded_sm()
655 .child(message_content),
656 ),
657 };
658
659 styled_message.into_any()
660 }
661
662 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
663 let is_open = self
664 .expanded_tool_uses
665 .get(&tool_use.id)
666 .copied()
667 .unwrap_or_default();
668
669 div().px_2p5().child(
670 v_flex()
671 .gap_1()
672 .rounded_lg()
673 .border_1()
674 .border_color(cx.theme().colors().border)
675 .child(
676 h_flex()
677 .justify_between()
678 .py_0p5()
679 .pl_1()
680 .pr_2()
681 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
682 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
683 .when(!is_open, |element| element.rounded_md())
684 .border_color(cx.theme().colors().border)
685 .child(
686 h_flex()
687 .gap_1()
688 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
689 cx.listener({
690 let tool_use_id = tool_use.id.clone();
691 move |this, _event, _window, _cx| {
692 let is_open = this
693 .expanded_tool_uses
694 .entry(tool_use_id.clone())
695 .or_insert(false);
696
697 *is_open = !*is_open;
698 }
699 }),
700 ))
701 .child(Label::new(tool_use.name)),
702 )
703 .child(
704 Label::new(match tool_use.status {
705 ToolUseStatus::Pending => "Pending",
706 ToolUseStatus::Running => "Running",
707 ToolUseStatus::Finished(_) => "Finished",
708 ToolUseStatus::Error(_) => "Error",
709 })
710 .size(LabelSize::XSmall)
711 .buffer_font(cx),
712 ),
713 )
714 .map(|parent| {
715 if !is_open {
716 return parent;
717 }
718
719 parent.child(
720 v_flex()
721 .child(
722 v_flex()
723 .gap_0p5()
724 .py_1()
725 .px_2p5()
726 .border_b_1()
727 .border_color(cx.theme().colors().border)
728 .child(Label::new("Input:"))
729 .child(Label::new(
730 serde_json::to_string_pretty(&tool_use.input)
731 .unwrap_or_default(),
732 )),
733 )
734 .map(|parent| match tool_use.status {
735 ToolUseStatus::Finished(output) => parent.child(
736 v_flex()
737 .gap_0p5()
738 .py_1()
739 .px_2p5()
740 .child(Label::new("Result:"))
741 .child(Label::new(output)),
742 ),
743 ToolUseStatus::Error(err) => parent.child(
744 v_flex()
745 .gap_0p5()
746 .py_1()
747 .px_2p5()
748 .child(Label::new("Error:"))
749 .child(Label::new(err)),
750 ),
751 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
752 }),
753 )
754 }),
755 )
756 }
757}
758
759impl Render for ActiveThread {
760 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
761 v_flex()
762 .size_full()
763 .child(list(self.list_state.clone()).flex_grow())
764 }
765}