1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use editor::{Editor, MultiBuffer};
6use gpui::{
7 list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity,
8 Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
9 TextStyleRefinement, UnderlineStyle, WeakEntity,
10};
11use language::{Buffer, LanguageRegistry};
12use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
13use markdown::{Markdown, MarkdownStyle};
14use settings::Settings as _;
15use theme::ThemeSettings;
16use ui::{prelude::*, Disclosure, KeyBinding};
17use workspace::Workspace;
18
19use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
20use crate::thread_store::ThreadStore;
21use crate::tool_use::{ToolUse, ToolUseStatus};
22use crate::ui::ContextPill;
23
/// View state for displaying a single [`Thread`]: the message list, the markdown rendered
/// for each message, any in-progress message edit, and the last error that was reported.
pub struct ActiveThread {
    /// Weak handle to the workspace; handed to tools when they are run.
    workspace: WeakEntity<Workspace>,
    /// Passed to every `Markdown` entity created by `render_markdown`.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks for its pending tools to be used.
    tools: Arc<ToolWorkingSet>,
    /// Store the thread is saved to when it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// IDs of the displayed messages; index `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// State of the message list element.
    list_state: ListState,
    /// The `Markdown` entity rendering each message's text, keyed by message ID.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited, if any, together with its editor state.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Whether each tool use is expanded; tool uses without an entry render collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Most recent error received via `ThreadEvent::ShowError`, until it is cleared.
    last_error: Option<ThreadError>,
    /// The observation of and event subscription to `thread` created in `new`
    /// (presumably stored so they stay active for as long as this view exists).
    _subscriptions: Vec<Subscription>,
}
38
/// State kept while a message is being edited.
struct EditMessageState {
    /// Editor holding the text of the message being edited.
    editor: Entity<Editor>,
}
42
43impl ActiveThread {
44 pub fn new(
45 thread: Entity<Thread>,
46 thread_store: Entity<ThreadStore>,
47 workspace: WeakEntity<Workspace>,
48 language_registry: Arc<LanguageRegistry>,
49 tools: Arc<ToolWorkingSet>,
50 window: &mut Window,
51 cx: &mut Context<Self>,
52 ) -> Self {
53 let subscriptions = vec![
54 cx.observe(&thread, |_, _, cx| cx.notify()),
55 cx.subscribe_in(&thread, window, Self::handle_thread_event),
56 ];
57
58 let mut this = Self {
59 workspace,
60 language_registry,
61 tools,
62 thread_store,
63 thread: thread.clone(),
64 messages: Vec::new(),
65 rendered_messages_by_id: HashMap::default(),
66 expanded_tool_uses: HashMap::default(),
67 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
68 let this = cx.entity().downgrade();
69 move |ix, window: &mut Window, cx: &mut App| {
70 this.update(cx, |this, cx| this.render_message(ix, window, cx))
71 .unwrap()
72 }
73 }),
74 editing_message: None,
75 last_error: None,
76 _subscriptions: subscriptions,
77 };
78
79 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
80 this.push_message(&message.id, message.text.clone(), window, cx);
81 }
82
83 this
84 }
85
    /// Returns the thread displayed by this view.
    pub fn thread(&self) -> &Entity<Thread> {
        &self.thread
    }
89
    /// Returns `true` when this view is tracking no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
93
94 pub fn summary(&self, cx: &App) -> Option<SharedString> {
95 self.thread.read(cx).summary()
96 }
97
98 pub fn summary_or_default(&self, cx: &App) -> SharedString {
99 self.thread.read(cx).summary_or_default()
100 }
101
102 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
103 self.last_error.take();
104 self.thread
105 .update(cx, |thread, _cx| thread.cancel_last_completion())
106 }
107
    /// Returns a clone of the most recently stored error, if there is one.
    pub fn last_error(&self) -> Option<ThreadError> {
        self.last_error.clone()
    }
111
112 pub fn clear_last_error(&mut self) {
113 self.last_error.take();
114 }
115
116 fn push_message(
117 &mut self,
118 id: &MessageId,
119 text: String,
120 window: &mut Window,
121 cx: &mut Context<Self>,
122 ) {
123 let old_len = self.messages.len();
124 self.messages.push(*id);
125 self.list_state.splice(old_len..old_len, 1);
126
127 let markdown = self.render_markdown(text.into(), window, cx);
128 self.rendered_messages_by_id.insert(*id, markdown);
129 self.list_state.scroll_to(ListOffset {
130 item_ix: old_len,
131 offset_in_item: Pixels(0.0),
132 });
133 }
134
135 fn edited_message(
136 &mut self,
137 id: &MessageId,
138 text: String,
139 window: &mut Window,
140 cx: &mut Context<Self>,
141 ) {
142 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
143 return;
144 };
145 self.list_state.splice(index..index + 1, 1);
146 let markdown = self.render_markdown(text.into(), window, cx);
147 self.rendered_messages_by_id.insert(*id, markdown);
148 }
149
150 fn deleted_message(&mut self, id: &MessageId) {
151 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
152 return;
153 };
154 self.messages.remove(index);
155 self.list_state.splice(index..index + 1, 0);
156 self.rendered_messages_by_id.remove(id);
157 }
158
159 fn render_markdown(
160 &self,
161 text: SharedString,
162 window: &Window,
163 cx: &mut Context<Self>,
164 ) -> Entity<Markdown> {
165 let theme_settings = ThemeSettings::get_global(cx);
166 let colors = cx.theme().colors();
167 let ui_font_size = TextSize::Default.rems(cx);
168 let buffer_font_size = TextSize::Small.rems(cx);
169 let mut text_style = window.text_style();
170
171 text_style.refine(&TextStyleRefinement {
172 font_family: Some(theme_settings.ui_font.family.clone()),
173 font_size: Some(ui_font_size.into()),
174 color: Some(cx.theme().colors().text),
175 ..Default::default()
176 });
177
178 let markdown_style = MarkdownStyle {
179 base_text_style: text_style,
180 syntax: cx.theme().syntax().clone(),
181 selection_background_color: cx.theme().players().local().selection,
182 code_block_overflow_x_scroll: true,
183 table_overflow_x_scroll: true,
184 code_block: StyleRefinement {
185 margin: EdgesRefinement {
186 top: Some(Length::Definite(rems(0.).into())),
187 left: Some(Length::Definite(rems(0.).into())),
188 right: Some(Length::Definite(rems(0.).into())),
189 bottom: Some(Length::Definite(rems(0.5).into())),
190 },
191 padding: EdgesRefinement {
192 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
194 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
195 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
196 },
197 background: Some(colors.editor_background.into()),
198 border_color: Some(colors.border_variant),
199 border_widths: EdgesRefinement {
200 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
201 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
202 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
203 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
204 },
205 text: Some(TextStyleRefinement {
206 font_family: Some(theme_settings.buffer_font.family.clone()),
207 font_size: Some(buffer_font_size.into()),
208 ..Default::default()
209 }),
210 ..Default::default()
211 },
212 inline_code: TextStyleRefinement {
213 font_family: Some(theme_settings.buffer_font.family.clone()),
214 font_size: Some(buffer_font_size.into()),
215 background_color: Some(colors.editor_foreground.opacity(0.1)),
216 ..Default::default()
217 },
218 link: TextStyleRefinement {
219 background_color: Some(colors.editor_foreground.opacity(0.025)),
220 underline: Some(UnderlineStyle {
221 color: Some(colors.text_accent.opacity(0.5)),
222 thickness: px(1.),
223 ..Default::default()
224 }),
225 ..Default::default()
226 },
227 ..Default::default()
228 };
229
230 cx.new(|cx| {
231 Markdown::new(
232 text,
233 markdown_style,
234 Some(self.language_registry.clone()),
235 None,
236 cx,
237 )
238 })
239 }
240
241 fn handle_thread_event(
242 &mut self,
243 _: &Entity<Thread>,
244 event: &ThreadEvent,
245 window: &mut Window,
246 cx: &mut Context<Self>,
247 ) {
248 match event {
249 ThreadEvent::ShowError(error) => {
250 self.last_error = Some(error.clone());
251 }
252 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
253 self.thread_store
254 .update(cx, |thread_store, cx| {
255 thread_store.save_thread(&self.thread, cx)
256 })
257 .detach_and_log_err(cx);
258 }
259 ThreadEvent::StreamedAssistantText(message_id, text) => {
260 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
261 markdown.update(cx, |markdown, cx| {
262 markdown.append(text, cx);
263 });
264 }
265 }
266 ThreadEvent::MessageAdded(message_id) => {
267 if let Some(message_text) = self
268 .thread
269 .read(cx)
270 .message(*message_id)
271 .map(|message| message.text.clone())
272 {
273 self.push_message(message_id, message_text, window, cx);
274 }
275
276 self.thread_store
277 .update(cx, |thread_store, cx| {
278 thread_store.save_thread(&self.thread, cx)
279 })
280 .detach_and_log_err(cx);
281
282 cx.notify();
283 }
284 ThreadEvent::MessageEdited(message_id) => {
285 if let Some(message_text) = self
286 .thread
287 .read(cx)
288 .message(*message_id)
289 .map(|message| message.text.clone())
290 {
291 self.edited_message(message_id, message_text, window, cx);
292 }
293
294 self.thread_store
295 .update(cx, |thread_store, cx| {
296 thread_store.save_thread(&self.thread, cx)
297 })
298 .detach_and_log_err(cx);
299
300 cx.notify();
301 }
302 ThreadEvent::MessageDeleted(message_id) => {
303 self.deleted_message(message_id);
304
305 self.thread_store
306 .update(cx, |thread_store, cx| {
307 thread_store.save_thread(&self.thread, cx)
308 })
309 .detach_and_log_err(cx);
310
311 cx.notify();
312 }
313 ThreadEvent::UsePendingTools => {
314 let pending_tool_uses = self
315 .thread
316 .read(cx)
317 .pending_tool_uses()
318 .into_iter()
319 .filter(|tool_use| tool_use.status.is_idle())
320 .cloned()
321 .collect::<Vec<_>>();
322
323 for tool_use in pending_tool_uses {
324 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
325 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
326
327 self.thread.update(cx, |thread, cx| {
328 thread.insert_tool_output(tool_use.id.clone(), task, cx);
329 });
330 }
331 }
332 }
333 ThreadEvent::ToolFinished { .. } => {
334 let all_tools_finished = self
335 .thread
336 .read(cx)
337 .pending_tool_uses()
338 .into_iter()
339 .all(|tool_use| tool_use.status.is_error());
340 if all_tools_finished {
341 let model_registry = LanguageModelRegistry::read_global(cx);
342 if let Some(model) = model_registry.active_model() {
343 self.thread.update(cx, |thread, cx| {
344 // Insert a user message to contain the tool results.
345 thread.insert_user_message(
346 // TODO: Sending up a user message without any content results in the model sending back
347 // responses that also don't have any content. We currently don't handle this case well,
348 // so for now we provide some text to keep the model on track.
349 "Here are the tool results.",
350 Vec::new(),
351 cx,
352 );
353 thread.send_to_model(model, RequestKind::Chat, true, cx);
354 });
355 }
356 }
357 }
358 }
359 }
360
361 fn start_editing_message(
362 &mut self,
363 message_id: MessageId,
364 message_text: String,
365 window: &mut Window,
366 cx: &mut Context<Self>,
367 ) {
368 let buffer = cx.new(|cx| {
369 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
370 });
371 let editor = cx.new(|cx| {
372 let mut editor = Editor::new(
373 editor::EditorMode::AutoHeight { max_lines: 8 },
374 buffer,
375 None,
376 false,
377 window,
378 cx,
379 );
380 editor.focus_handle(cx).focus(window);
381 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
382 editor
383 });
384 self.editing_message = Some((
385 message_id,
386 EditMessageState {
387 editor: editor.clone(),
388 },
389 ));
390 cx.notify();
391 }
392
393 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
394 self.editing_message.take();
395 cx.notify();
396 }
397
398 fn confirm_editing_message(
399 &mut self,
400 _: &menu::Confirm,
401 _: &mut Window,
402 cx: &mut Context<Self>,
403 ) {
404 let Some((message_id, state)) = self.editing_message.take() else {
405 return;
406 };
407 let edited_text = state.editor.read(cx).text(cx);
408 self.thread.update(cx, |thread, cx| {
409 thread.edit_message(message_id, Role::User, edited_text, cx);
410 for message_id in self.messages_after(message_id) {
411 thread.delete_message(*message_id, cx);
412 }
413 });
414
415 let provider = LanguageModelRegistry::read_global(cx).active_provider();
416 if provider
417 .as_ref()
418 .map_or(false, |provider| provider.must_accept_terms(cx))
419 {
420 cx.notify();
421 return;
422 }
423 let model_registry = LanguageModelRegistry::read_global(cx);
424 let Some(model) = model_registry.active_model() else {
425 return;
426 };
427
428 self.thread.update(cx, |thread, cx| {
429 thread.send_to_model(model, RequestKind::Chat, false, cx)
430 });
431 cx.notify();
432 }
433
434 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
435 self.messages
436 .iter()
437 .rev()
438 .find(|message_id| {
439 self.thread
440 .read(cx)
441 .message(**message_id)
442 .map_or(false, |message| message.role == Role::User)
443 })
444 .cloned()
445 }
446
447 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
448 self.messages
449 .iter()
450 .position(|id| *id == message_id)
451 .map(|index| &self.messages[index + 1..])
452 .unwrap_or(&[])
453 }
454
    /// Renders the list item at index `ix`, i.e. the message whose ID is `self.messages[ix]`.
    ///
    /// Returns an empty element when the thread has no message with that ID, when no
    /// markdown has been rendered for it, or when it is a user message that only carries
    /// tool results. Otherwise the message is styled according to its role.
    ///
    /// # Panics
    ///
    /// Panics if `ix` is out of bounds for `self.messages`.
    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
        let message_id = self.messages[ix];
        let Some(message) = self.thread.read(cx).message(message_id) else {
            return Empty.into_any();
        };

        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
            return Empty.into_any();
        };

        let context = self.thread.read(cx).context_for_message(message_id);
        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
        let colors = cx.theme().colors();

        // Don't render user messages that are just there for returning tool results.
        if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
            return Empty.into_any();
        }

        // Only the most recent user message can be edited.
        let allow_editing_message =
            message.role == Role::User && self.last_user_message(cx) == Some(message_id);

        // The editor for this message, if this is the message currently being edited.
        let edit_message_editor = self
            .editing_message
            .as_ref()
            .filter(|(id, _)| *id == message_id)
            .map(|(_, state)| state.editor.clone());

        // Message body: the editor while editing, otherwise the rendered markdown,
        // followed by a row of context pills when the message has any context.
        let message_content = v_flex()
            .child(
                if let Some(edit_message_editor) = edit_message_editor.clone() {
                    div()
                        .key_context("EditMessageEditor")
                        .on_action(cx.listener(Self::cancel_editing_message))
                        .on_action(cx.listener(Self::confirm_editing_message))
                        .p_2p5()
                        .child(edit_message_editor)
                } else {
                    div().p_2p5().text_ui(cx).child(markdown.clone())
                },
            )
            .when_some(context, |parent, context| {
                if !context.is_empty() {
                    parent.child(
                        h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
                            context
                                .into_iter()
                                .map(|context| ContextPill::added(context, false, false, None)),
                        ),
                    )
                } else {
                    parent
                }
            });

        let styled_message = match message.role {
            // User messages: a bordered card with a "You" header that also hosts the
            // edit controls.
            Role::User => v_flex()
                .id(("message-container", ix))
                .pt_2p5()
                .px_2p5()
                .child(
                    v_flex()
                        .bg(colors.editor_background)
                        .rounded_lg()
                        .border_1()
                        .border_color(colors.border)
                        .shadow_sm()
                        .child(
                            h_flex()
                                .py_1()
                                .px_2()
                                .bg(colors.editor_foreground.opacity(0.05))
                                .border_b_1()
                                .border_color(colors.border)
                                .justify_between()
                                .rounded_t(px(6.))
                                .child(
                                    h_flex()
                                        .gap_1p5()
                                        .child(
                                            Icon::new(IconName::PersonCircle)
                                                .size(IconSize::XSmall)
                                                .color(Color::Muted),
                                        )
                                        .child(
                                            Label::new("You")
                                                .size(LabelSize::Small)
                                                .color(Color::Muted),
                                        ),
                                )
                                // While editing: "Cancel" and "Regenerate" buttons showing
                                // the key bindings for the editor's focus handle.
                                .when_some(
                                    edit_message_editor.clone(),
                                    |this, edit_message_editor| {
                                        let focus_handle = edit_message_editor.focus_handle(cx);
                                        this.child(
                                            h_flex()
                                                .gap_1()
                                                .child(
                                                    Button::new("cancel-edit-message", "Cancel")
                                                        .key_binding(KeyBinding::for_action_in(
                                                            &menu::Cancel,
                                                            &focus_handle,
                                                            window,
                                                            cx,
                                                        )),
                                                )
                                                .child(
                                                    Button::new(
                                                        "confirm-edit-message",
                                                        "Regenerate",
                                                    )
                                                    .key_binding(KeyBinding::for_action_in(
                                                        &menu::Confirm,
                                                        &focus_handle,
                                                        window,
                                                        cx,
                                                    )),
                                                ),
                                        )
                                    },
                                )
                                // Not editing: offer an "Edit" button on the editable message.
                                .when(
                                    edit_message_editor.is_none() && allow_editing_message,
                                    |this| {
                                        this.child(Button::new("edit-message", "Edit").on_click(
                                            cx.listener({
                                                let message_text = message.text.clone();
                                                move |this, _, window, cx| {
                                                    this.start_editing_message(
                                                        message_id,
                                                        message_text.clone(),
                                                        window,
                                                        cx,
                                                    );
                                                }
                                            }),
                                        ))
                                    },
                                ),
                        )
                        .child(message_content),
                ),
            // Assistant messages: the content followed by the message's tool uses, if any.
            Role::Assistant => div()
                .id(("message-container", ix))
                .child(message_content)
                .map(|parent| {
                    if tool_uses.is_empty() {
                        return parent;
                    }

                    parent.child(
                        v_flex().children(
                            tool_uses
                                .into_iter()
                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
                        ),
                    )
                }),
            // System messages: the content on a rounded editor-background panel.
            Role::System => div().id(("message-container", ix)).py_1().px_2().child(
                v_flex()
                    .bg(colors.editor_background)
                    .rounded_md()
                    .child(message_content),
            ),
        };

        styled_message.into_any()
    }
623
624 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
625 let is_open = self
626 .expanded_tool_uses
627 .get(&tool_use.id)
628 .copied()
629 .unwrap_or_default();
630
631 div().px_2p5().child(
632 v_flex()
633 .gap_1()
634 .rounded_lg()
635 .border_1()
636 .border_color(cx.theme().colors().border)
637 .child(
638 h_flex()
639 .justify_between()
640 .py_0p5()
641 .pl_1()
642 .pr_2()
643 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
644 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
645 .when(!is_open, |element| element.rounded(px(6.)))
646 .border_color(cx.theme().colors().border)
647 .child(
648 h_flex()
649 .gap_1()
650 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
651 cx.listener({
652 let tool_use_id = tool_use.id.clone();
653 move |this, _event, _window, _cx| {
654 let is_open = this
655 .expanded_tool_uses
656 .entry(tool_use_id.clone())
657 .or_insert(false);
658
659 *is_open = !*is_open;
660 }
661 }),
662 ))
663 .child(Label::new(tool_use.name)),
664 )
665 .child(
666 Label::new(match tool_use.status {
667 ToolUseStatus::Pending => "Pending",
668 ToolUseStatus::Running => "Running",
669 ToolUseStatus::Finished(_) => "Finished",
670 ToolUseStatus::Error(_) => "Error",
671 })
672 .size(LabelSize::XSmall)
673 .buffer_font(cx),
674 ),
675 )
676 .map(|parent| {
677 if !is_open {
678 return parent;
679 }
680
681 parent.child(
682 v_flex()
683 .child(
684 v_flex()
685 .gap_0p5()
686 .py_1()
687 .px_2p5()
688 .border_b_1()
689 .border_color(cx.theme().colors().border)
690 .child(Label::new("Input:"))
691 .child(Label::new(
692 serde_json::to_string_pretty(&tool_use.input)
693 .unwrap_or_default(),
694 )),
695 )
696 .map(|parent| match tool_use.status {
697 ToolUseStatus::Finished(output) => parent.child(
698 v_flex()
699 .gap_0p5()
700 .py_1()
701 .px_2p5()
702 .child(Label::new("Result:"))
703 .child(Label::new(output)),
704 ),
705 ToolUseStatus::Error(err) => parent.child(
706 v_flex()
707 .gap_0p5()
708 .py_1()
709 .px_2p5()
710 .child(Label::new("Error:"))
711 .child(Label::new(err)),
712 ),
713 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
714 }),
715 )
716 }),
717 )
718 }
719}
720
721impl Render for ActiveThread {
722 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
723 v_flex()
724 .size_full()
725 .child(list(self.list_state.clone()).flex_grow())
726 }
727}