1use std::sync::Arc;
2
3use collections::HashMap;
4use editor::{Editor, MultiBuffer};
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
7 Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
8 Task, TextStyleRefinement, UnderlineStyle,
9};
10use language::{Buffer, LanguageRegistry};
11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::{prelude::*, Disclosure, KeyBinding};
16use util::ResultExt as _;
17
18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
19use crate::thread_store::ThreadStore;
20use crate::tool_use::{ToolUse, ToolUseStatus};
21use crate::ui::ContextPill;
22
/// A view that renders the messages of a [`Thread`] as a scrollable list and lets the
/// user edit the most recent user message.
///
/// NOTE: Rust drops struct fields in declaration order; keep that in mind before
/// reordering the fields below.
pub struct ActiveThread {
    /// Handed to every [`Markdown`] entity this view creates.
    language_registry: Arc<LanguageRegistry>,
    /// Store used to persist `thread` whenever it changes.
    thread_store: Entity<ThreadStore>,
    /// The thread whose messages are displayed.
    thread: Entity<Thread>,
    /// The most recently spawned save task; replaced each time a save is requested.
    save_thread_task: Option<Task<()>>,
    /// Ids of the listed messages in display order; index `i` is list item `i`.
    messages: Vec<MessageId>,
    /// List state whose item count is kept in step with `messages`.
    list_state: ListState,
    /// Rendered markdown for each listed message.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited (if any) together with its editor state.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Whether each tool use's details are expanded; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// The last error reported through `ThreadEvent::ShowError`, until cleared.
    last_error: Option<ThreadError>,
    /// Observation/subscription on `thread` created in `new`, held for this view's lifetime.
    _subscriptions: Vec<Subscription>,
}
36
/// State for an in-progress edit of a user message.
struct EditMessageState {
    /// Editor holding the edited text; its contents are read when the edit is confirmed.
    editor: Entity<Editor>,
}
40
41impl ActiveThread {
42 pub fn new(
43 thread: Entity<Thread>,
44 thread_store: Entity<ThreadStore>,
45 language_registry: Arc<LanguageRegistry>,
46 window: &mut Window,
47 cx: &mut Context<Self>,
48 ) -> Self {
49 let subscriptions = vec![
50 cx.observe(&thread, |_, _, cx| cx.notify()),
51 cx.subscribe_in(&thread, window, Self::handle_thread_event),
52 ];
53
54 let mut this = Self {
55 language_registry,
56 thread_store,
57 thread: thread.clone(),
58 save_thread_task: None,
59 messages: Vec::new(),
60 rendered_messages_by_id: HashMap::default(),
61 expanded_tool_uses: HashMap::default(),
62 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
63 let this = cx.entity().downgrade();
64 move |ix, window: &mut Window, cx: &mut App| {
65 this.update(cx, |this, cx| this.render_message(ix, window, cx))
66 .unwrap()
67 }
68 }),
69 editing_message: None,
70 last_error: None,
71 _subscriptions: subscriptions,
72 };
73
74 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
75 this.push_message(&message.id, message.text.clone(), window, cx);
76 }
77
78 this
79 }
80
81 pub fn thread(&self) -> &Entity<Thread> {
82 &self.thread
83 }
84
85 pub fn is_empty(&self) -> bool {
86 self.messages.is_empty()
87 }
88
89 pub fn summary(&self, cx: &App) -> Option<SharedString> {
90 self.thread.read(cx).summary()
91 }
92
93 pub fn summary_or_default(&self, cx: &App) -> SharedString {
94 self.thread.read(cx).summary_or_default()
95 }
96
97 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
98 self.last_error.take();
99 self.thread
100 .update(cx, |thread, _cx| thread.cancel_last_completion())
101 }
102
103 pub fn last_error(&self) -> Option<ThreadError> {
104 self.last_error.clone()
105 }
106
107 pub fn clear_last_error(&mut self) {
108 self.last_error.take();
109 }
110
111 fn push_message(
112 &mut self,
113 id: &MessageId,
114 text: String,
115 window: &mut Window,
116 cx: &mut Context<Self>,
117 ) {
118 let old_len = self.messages.len();
119 self.messages.push(*id);
120 self.list_state.splice(old_len..old_len, 1);
121
122 let markdown = self.render_markdown(text.into(), window, cx);
123 self.rendered_messages_by_id.insert(*id, markdown);
124 self.list_state.scroll_to(ListOffset {
125 item_ix: old_len,
126 offset_in_item: Pixels(0.0),
127 });
128 }
129
130 fn edited_message(
131 &mut self,
132 id: &MessageId,
133 text: String,
134 window: &mut Window,
135 cx: &mut Context<Self>,
136 ) {
137 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
138 return;
139 };
140 self.list_state.splice(index..index + 1, 1);
141 let markdown = self.render_markdown(text.into(), window, cx);
142 self.rendered_messages_by_id.insert(*id, markdown);
143 }
144
145 fn deleted_message(&mut self, id: &MessageId) {
146 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
147 return;
148 };
149 self.messages.remove(index);
150 self.list_state.splice(index..index + 1, 0);
151 self.rendered_messages_by_id.remove(id);
152 }
153
154 fn render_markdown(
155 &self,
156 text: SharedString,
157 window: &Window,
158 cx: &mut Context<Self>,
159 ) -> Entity<Markdown> {
160 let theme_settings = ThemeSettings::get_global(cx);
161 let colors = cx.theme().colors();
162 let ui_font_size = TextSize::Default.rems(cx);
163 let buffer_font_size = TextSize::Small.rems(cx);
164 let mut text_style = window.text_style();
165
166 text_style.refine(&TextStyleRefinement {
167 font_family: Some(theme_settings.ui_font.family.clone()),
168 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
169 font_features: Some(theme_settings.ui_font.features.clone()),
170 font_size: Some(ui_font_size.into()),
171 color: Some(cx.theme().colors().text),
172 ..Default::default()
173 });
174
175 let markdown_style = MarkdownStyle {
176 base_text_style: text_style,
177 syntax: cx.theme().syntax().clone(),
178 selection_background_color: cx.theme().players().local().selection,
179 code_block_overflow_x_scroll: true,
180 table_overflow_x_scroll: true,
181 code_block: StyleRefinement {
182 margin: EdgesRefinement {
183 top: Some(Length::Definite(rems(0.).into())),
184 left: Some(Length::Definite(rems(0.).into())),
185 right: Some(Length::Definite(rems(0.).into())),
186 bottom: Some(Length::Definite(rems(0.5).into())),
187 },
188 padding: EdgesRefinement {
189 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
190 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
191 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
192 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193 },
194 background: Some(colors.editor_background.into()),
195 border_color: Some(colors.border_variant),
196 border_widths: EdgesRefinement {
197 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
198 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
199 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
200 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
201 },
202 text: Some(TextStyleRefinement {
203 font_family: Some(theme_settings.buffer_font.family.clone()),
204 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
205 font_features: Some(theme_settings.buffer_font.features.clone()),
206 font_size: Some(buffer_font_size.into()),
207 ..Default::default()
208 }),
209 ..Default::default()
210 },
211 inline_code: TextStyleRefinement {
212 font_family: Some(theme_settings.buffer_font.family.clone()),
213 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
214 font_features: Some(theme_settings.buffer_font.features.clone()),
215 font_size: Some(buffer_font_size.into()),
216 background_color: Some(colors.editor_foreground.opacity(0.1)),
217 ..Default::default()
218 },
219 link: TextStyleRefinement {
220 background_color: Some(colors.editor_foreground.opacity(0.025)),
221 underline: Some(UnderlineStyle {
222 color: Some(colors.text_accent.opacity(0.5)),
223 thickness: px(1.),
224 ..Default::default()
225 }),
226 ..Default::default()
227 },
228 ..Default::default()
229 };
230
231 cx.new(|cx| {
232 Markdown::new(
233 text,
234 markdown_style,
235 Some(self.language_registry.clone()),
236 None,
237 cx,
238 )
239 })
240 }
241
242 fn handle_thread_event(
243 &mut self,
244 _: &Entity<Thread>,
245 event: &ThreadEvent,
246 window: &mut Window,
247 cx: &mut Context<Self>,
248 ) {
249 match event {
250 ThreadEvent::ShowError(error) => {
251 self.last_error = Some(error.clone());
252 }
253 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
254 self.save_thread(cx);
255 }
256 ThreadEvent::StreamedAssistantText(message_id, text) => {
257 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
258 markdown.update(cx, |markdown, cx| {
259 markdown.append(text, cx);
260 });
261 }
262 }
263 ThreadEvent::MessageAdded(message_id) => {
264 if let Some(message_text) = self
265 .thread
266 .read(cx)
267 .message(*message_id)
268 .map(|message| message.text.clone())
269 {
270 self.push_message(message_id, message_text, window, cx);
271 }
272
273 self.save_thread(cx);
274 cx.notify();
275 }
276 ThreadEvent::MessageEdited(message_id) => {
277 if let Some(message_text) = self
278 .thread
279 .read(cx)
280 .message(*message_id)
281 .map(|message| message.text.clone())
282 {
283 self.edited_message(message_id, message_text, window, cx);
284 }
285
286 self.save_thread(cx);
287 cx.notify();
288 }
289 ThreadEvent::MessageDeleted(message_id) => {
290 self.deleted_message(message_id);
291 self.save_thread(cx);
292 cx.notify();
293 }
294 ThreadEvent::UsePendingTools => {
295 self.thread.update(cx, |thread, cx| {
296 thread.use_pending_tools(cx);
297 });
298 }
299 ThreadEvent::ToolFinished { .. } => {
300 if self.thread.read(cx).all_tools_finished() {
301 let model_registry = LanguageModelRegistry::read_global(cx);
302 if let Some(model) = model_registry.active_model() {
303 self.thread.update(cx, |thread, cx| {
304 thread.send_tool_results_to_model(model, cx);
305 });
306 }
307 }
308 }
309 }
310 }
311
312 /// Spawns a task to save the active thread.
313 ///
314 /// Only one task to save the thread will be in flight at a time.
315 fn save_thread(&mut self, cx: &mut Context<Self>) {
316 let thread = self.thread.clone();
317 self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
318 let task = this
319 .update(&mut cx, |this, cx| {
320 this.thread_store
321 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
322 })
323 .ok();
324
325 if let Some(task) = task {
326 task.await.log_err();
327 }
328 }));
329 }
330
331 fn start_editing_message(
332 &mut self,
333 message_id: MessageId,
334 message_text: String,
335 window: &mut Window,
336 cx: &mut Context<Self>,
337 ) {
338 let buffer = cx.new(|cx| {
339 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
340 });
341 let editor = cx.new(|cx| {
342 let mut editor = Editor::new(
343 editor::EditorMode::AutoHeight { max_lines: 8 },
344 buffer,
345 None,
346 false,
347 window,
348 cx,
349 );
350 editor.focus_handle(cx).focus(window);
351 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
352 editor
353 });
354 self.editing_message = Some((
355 message_id,
356 EditMessageState {
357 editor: editor.clone(),
358 },
359 ));
360 cx.notify();
361 }
362
363 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
364 self.editing_message.take();
365 cx.notify();
366 }
367
368 fn confirm_editing_message(
369 &mut self,
370 _: &menu::Confirm,
371 _: &mut Window,
372 cx: &mut Context<Self>,
373 ) {
374 let Some((message_id, state)) = self.editing_message.take() else {
375 return;
376 };
377 let edited_text = state.editor.read(cx).text(cx);
378 self.thread.update(cx, |thread, cx| {
379 thread.edit_message(message_id, Role::User, edited_text, cx);
380 for message_id in self.messages_after(message_id) {
381 thread.delete_message(*message_id, cx);
382 }
383 });
384
385 let provider = LanguageModelRegistry::read_global(cx).active_provider();
386 if provider
387 .as_ref()
388 .map_or(false, |provider| provider.must_accept_terms(cx))
389 {
390 cx.notify();
391 return;
392 }
393 let model_registry = LanguageModelRegistry::read_global(cx);
394 let Some(model) = model_registry.active_model() else {
395 return;
396 };
397
398 self.thread.update(cx, |thread, cx| {
399 thread.send_to_model(model, RequestKind::Chat, false, cx)
400 });
401 cx.notify();
402 }
403
404 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
405 self.messages
406 .iter()
407 .rev()
408 .find(|message_id| {
409 self.thread
410 .read(cx)
411 .message(**message_id)
412 .map_or(false, |message| message.role == Role::User)
413 })
414 .cloned()
415 }
416
417 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
418 self.messages
419 .iter()
420 .position(|id| *id == message_id)
421 .map(|index| &self.messages[index + 1..])
422 .unwrap_or(&[])
423 }
424
425 fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
426 self.cancel_editing_message(&menu::Cancel, window, cx);
427 }
428
429 fn handle_regenerate_click(
430 &mut self,
431 _: &ClickEvent,
432 window: &mut Window,
433 cx: &mut Context<Self>,
434 ) {
435 self.confirm_editing_message(&menu::Confirm, window, cx);
436 }
437
438 fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
439 let message_id = self.messages[ix];
440 let Some(message) = self.thread.read(cx).message(message_id) else {
441 return Empty.into_any();
442 };
443
444 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
445 return Empty.into_any();
446 };
447
448 let context = self.thread.read(cx).context_for_message(message_id);
449 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
450 let colors = cx.theme().colors();
451
452 // Don't render user messages that are just there for returning tool results.
453 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
454 return Empty.into_any();
455 }
456
457 let allow_editing_message =
458 message.role == Role::User && self.last_user_message(cx) == Some(message_id);
459
460 let edit_message_editor = self
461 .editing_message
462 .as_ref()
463 .filter(|(id, _)| *id == message_id)
464 .map(|(_, state)| state.editor.clone());
465
466 let message_content = v_flex()
467 .child(
468 if let Some(edit_message_editor) = edit_message_editor.clone() {
469 div()
470 .key_context("EditMessageEditor")
471 .on_action(cx.listener(Self::cancel_editing_message))
472 .on_action(cx.listener(Self::confirm_editing_message))
473 .p_2p5()
474 .child(edit_message_editor)
475 } else {
476 div().p_2p5().text_ui(cx).child(markdown.clone())
477 },
478 )
479 .when_some(context, |parent, context| {
480 if !context.is_empty() {
481 parent.child(
482 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
483 context
484 .into_iter()
485 .map(|context| ContextPill::added(context, false, false, None)),
486 ),
487 )
488 } else {
489 parent
490 }
491 });
492
493 let styled_message = match message.role {
494 Role::User => v_flex()
495 .id(("message-container", ix))
496 .pt_2p5()
497 .px_2p5()
498 .child(
499 v_flex()
500 .bg(colors.editor_background)
501 .rounded_lg()
502 .border_1()
503 .border_color(colors.border)
504 .shadow_sm()
505 .child(
506 h_flex()
507 .py_1()
508 .pl_2()
509 .pr_1()
510 .bg(colors.editor_foreground.opacity(0.05))
511 .border_b_1()
512 .border_color(colors.border)
513 .justify_between()
514 .rounded_t(px(6.))
515 .child(
516 h_flex()
517 .gap_1p5()
518 .child(
519 Icon::new(IconName::PersonCircle)
520 .size(IconSize::XSmall)
521 .color(Color::Muted),
522 )
523 .child(
524 Label::new("You")
525 .size(LabelSize::Small)
526 .color(Color::Muted),
527 ),
528 )
529 .when_some(
530 edit_message_editor.clone(),
531 |this, edit_message_editor| {
532 let focus_handle = edit_message_editor.focus_handle(cx);
533 this.child(
534 h_flex()
535 .gap_1()
536 .child(
537 Button::new("cancel-edit-message", "Cancel")
538 .label_size(LabelSize::Small)
539 .key_binding(
540 KeyBinding::for_action_in(
541 &menu::Cancel,
542 &focus_handle,
543 window,
544 cx,
545 )
546 .map(|kb| kb.size(rems_from_px(12.))),
547 )
548 .on_click(
549 cx.listener(Self::handle_cancel_click),
550 ),
551 )
552 .child(
553 Button::new(
554 "confirm-edit-message",
555 "Regenerate",
556 )
557 .label_size(LabelSize::Small)
558 .key_binding(
559 KeyBinding::for_action_in(
560 &menu::Confirm,
561 &focus_handle,
562 window,
563 cx,
564 )
565 .map(|kb| kb.size(rems_from_px(12.))),
566 )
567 .on_click(
568 cx.listener(Self::handle_regenerate_click),
569 ),
570 ),
571 )
572 },
573 )
574 .when(
575 edit_message_editor.is_none() && allow_editing_message,
576 |this| {
577 this.child(
578 Button::new("edit-message", "Edit")
579 .label_size(LabelSize::Small)
580 .on_click(cx.listener({
581 let message_text = message.text.clone();
582 move |this, _, window, cx| {
583 this.start_editing_message(
584 message_id,
585 message_text.clone(),
586 window,
587 cx,
588 );
589 }
590 })),
591 )
592 },
593 ),
594 )
595 .child(message_content),
596 ),
597 Role::Assistant => div()
598 .id(("message-container", ix))
599 .child(message_content)
600 .map(|parent| {
601 if tool_uses.is_empty() {
602 return parent;
603 }
604
605 parent.child(
606 v_flex().children(
607 tool_uses
608 .into_iter()
609 .map(|tool_use| self.render_tool_use(tool_use, cx)),
610 ),
611 )
612 }),
613 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
614 v_flex()
615 .bg(colors.editor_background)
616 .rounded_sm()
617 .child(message_content),
618 ),
619 };
620
621 styled_message.into_any()
622 }
623
624 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
625 let is_open = self
626 .expanded_tool_uses
627 .get(&tool_use.id)
628 .copied()
629 .unwrap_or_default();
630
631 div().px_2p5().child(
632 v_flex()
633 .gap_1()
634 .rounded_lg()
635 .border_1()
636 .border_color(cx.theme().colors().border)
637 .child(
638 h_flex()
639 .justify_between()
640 .py_0p5()
641 .pl_1()
642 .pr_2()
643 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
644 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
645 .when(!is_open, |element| element.rounded_md())
646 .border_color(cx.theme().colors().border)
647 .child(
648 h_flex()
649 .gap_1()
650 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
651 cx.listener({
652 let tool_use_id = tool_use.id.clone();
653 move |this, _event, _window, _cx| {
654 let is_open = this
655 .expanded_tool_uses
656 .entry(tool_use_id.clone())
657 .or_insert(false);
658
659 *is_open = !*is_open;
660 }
661 }),
662 ))
663 .child(Label::new(tool_use.name)),
664 )
665 .child(
666 Label::new(match tool_use.status {
667 ToolUseStatus::Pending => "Pending",
668 ToolUseStatus::Running => "Running",
669 ToolUseStatus::Finished(_) => "Finished",
670 ToolUseStatus::Error(_) => "Error",
671 })
672 .size(LabelSize::XSmall)
673 .buffer_font(cx),
674 ),
675 )
676 .map(|parent| {
677 if !is_open {
678 return parent;
679 }
680
681 parent.child(
682 v_flex()
683 .child(
684 v_flex()
685 .gap_0p5()
686 .py_1()
687 .px_2p5()
688 .border_b_1()
689 .border_color(cx.theme().colors().border)
690 .child(Label::new("Input:"))
691 .child(Label::new(
692 serde_json::to_string_pretty(&tool_use.input)
693 .unwrap_or_default(),
694 )),
695 )
696 .map(|parent| match tool_use.status {
697 ToolUseStatus::Finished(output) => parent.child(
698 v_flex()
699 .gap_0p5()
700 .py_1()
701 .px_2p5()
702 .child(Label::new("Result:"))
703 .child(Label::new(output)),
704 ),
705 ToolUseStatus::Error(err) => parent.child(
706 v_flex()
707 .gap_0p5()
708 .py_1()
709 .px_2p5()
710 .child(Label::new("Error:"))
711 .child(Label::new(err)),
712 ),
713 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
714 }),
715 )
716 }),
717 )
718 }
719}
720
721impl Render for ActiveThread {
722 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
723 v_flex()
724 .size_full()
725 .child(list(self.list_state.clone()).flex_grow())
726 }
727}