1use std::sync::Arc;
2
3use collections::HashMap;
4use editor::{Editor, MultiBuffer};
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
7 Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
8 Task, TextStyleRefinement, UnderlineStyle,
9};
10use language::{Buffer, LanguageRegistry};
11use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::{prelude::*, Disclosure, KeyBinding};
16use util::ResultExt as _;
17
18use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
19use crate::thread_store::ThreadStore;
20use crate::tool_use::{ToolUse, ToolUseStatus};
21use crate::ui::ContextPill;
22
/// View that renders the messages of a [`Thread`] as a list and lets the user
/// edit and regenerate from the most recent user message.
pub struct ActiveThread {
    /// Passed to `Markdown::new` when rendering message text.
    language_registry: Arc<LanguageRegistry>,
    /// Store used to persist the thread (see `save_thread`).
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Most recently spawned save task; each new save replaces it.
    save_thread_task: Option<Task<()>>,
    /// Message ids in display order; entry `i` backs list item `i`.
    messages: Vec<MessageId>,
    /// List state whose item count is kept in step with `messages` via `splice`.
    list_state: ListState,
    /// Rendered markdown for each message's text, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// The message currently being edited, together with its editor state.
    editing_message: Option<(MessageId, EditMessageState)>,
    /// Disclosure state per tool use; a missing entry means collapsed.
    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
    /// Last error reported by the thread, until taken or cleared.
    last_error: Option<ThreadError>,
    /// Subscriptions created in `new`; never read, only held (presumably so
    /// they stay active for the lifetime of this view).
    _subscriptions: Vec<Subscription>,
}
36
/// State for an in-progress edit of a user message.
struct EditMessageState {
    /// Editor holding the text being edited; read on confirm.
    editor: Entity<Editor>,
}
40
41impl ActiveThread {
42 pub fn new(
43 thread: Entity<Thread>,
44 thread_store: Entity<ThreadStore>,
45 language_registry: Arc<LanguageRegistry>,
46 window: &mut Window,
47 cx: &mut Context<Self>,
48 ) -> Self {
49 let subscriptions = vec![
50 cx.observe(&thread, |_, _, cx| cx.notify()),
51 cx.subscribe_in(&thread, window, Self::handle_thread_event),
52 ];
53
54 let mut this = Self {
55 language_registry,
56 thread_store,
57 thread: thread.clone(),
58 save_thread_task: None,
59 messages: Vec::new(),
60 rendered_messages_by_id: HashMap::default(),
61 expanded_tool_uses: HashMap::default(),
62 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
63 let this = cx.entity().downgrade();
64 move |ix, window: &mut Window, cx: &mut App| {
65 this.update(cx, |this, cx| this.render_message(ix, window, cx))
66 .unwrap()
67 }
68 }),
69 editing_message: None,
70 last_error: None,
71 _subscriptions: subscriptions,
72 };
73
74 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
75 this.push_message(&message.id, message.text.clone(), window, cx);
76 }
77
78 this
79 }
80
81 pub fn thread(&self) -> &Entity<Thread> {
82 &self.thread
83 }
84
85 pub fn is_empty(&self) -> bool {
86 self.messages.is_empty()
87 }
88
89 pub fn summary(&self, cx: &App) -> Option<SharedString> {
90 self.thread.read(cx).summary()
91 }
92
93 pub fn summary_or_default(&self, cx: &App) -> SharedString {
94 self.thread.read(cx).summary_or_default()
95 }
96
97 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
98 self.last_error.take();
99 self.thread
100 .update(cx, |thread, _cx| thread.cancel_last_completion())
101 }
102
103 pub fn last_error(&self) -> Option<ThreadError> {
104 self.last_error.clone()
105 }
106
107 pub fn clear_last_error(&mut self) {
108 self.last_error.take();
109 }
110
111 fn push_message(
112 &mut self,
113 id: &MessageId,
114 text: String,
115 window: &mut Window,
116 cx: &mut Context<Self>,
117 ) {
118 let old_len = self.messages.len();
119 self.messages.push(*id);
120 self.list_state.splice(old_len..old_len, 1);
121
122 let markdown = self.render_markdown(text.into(), window, cx);
123 self.rendered_messages_by_id.insert(*id, markdown);
124 self.list_state.scroll_to(ListOffset {
125 item_ix: old_len,
126 offset_in_item: Pixels(0.0),
127 });
128 }
129
130 fn edited_message(
131 &mut self,
132 id: &MessageId,
133 text: String,
134 window: &mut Window,
135 cx: &mut Context<Self>,
136 ) {
137 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
138 return;
139 };
140 self.list_state.splice(index..index + 1, 1);
141 let markdown = self.render_markdown(text.into(), window, cx);
142 self.rendered_messages_by_id.insert(*id, markdown);
143 }
144
145 fn deleted_message(&mut self, id: &MessageId) {
146 let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
147 return;
148 };
149 self.messages.remove(index);
150 self.list_state.splice(index..index + 1, 0);
151 self.rendered_messages_by_id.remove(id);
152 }
153
154 fn render_markdown(
155 &self,
156 text: SharedString,
157 window: &Window,
158 cx: &mut Context<Self>,
159 ) -> Entity<Markdown> {
160 let theme_settings = ThemeSettings::get_global(cx);
161 let colors = cx.theme().colors();
162 let ui_font_size = TextSize::Default.rems(cx);
163 let buffer_font_size = TextSize::Small.rems(cx);
164 let mut text_style = window.text_style();
165
166 text_style.refine(&TextStyleRefinement {
167 font_family: Some(theme_settings.ui_font.family.clone()),
168 font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
169 font_features: Some(theme_settings.ui_font.features.clone()),
170 font_size: Some(ui_font_size.into()),
171 color: Some(cx.theme().colors().text),
172 ..Default::default()
173 });
174
175 let markdown_style = MarkdownStyle {
176 base_text_style: text_style,
177 syntax: cx.theme().syntax().clone(),
178 selection_background_color: cx.theme().players().local().selection,
179 code_block_overflow_x_scroll: true,
180 table_overflow_x_scroll: true,
181 code_block: StyleRefinement {
182 margin: EdgesRefinement {
183 top: Some(Length::Definite(rems(0.).into())),
184 left: Some(Length::Definite(rems(0.).into())),
185 right: Some(Length::Definite(rems(0.).into())),
186 bottom: Some(Length::Definite(rems(0.5).into())),
187 },
188 padding: EdgesRefinement {
189 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
190 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
191 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
192 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
193 },
194 background: Some(colors.editor_background.into()),
195 border_color: Some(colors.border_variant),
196 border_widths: EdgesRefinement {
197 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
198 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
199 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
200 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
201 },
202 text: Some(TextStyleRefinement {
203 font_family: Some(theme_settings.buffer_font.family.clone()),
204 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
205 font_features: Some(theme_settings.buffer_font.features.clone()),
206 font_size: Some(buffer_font_size.into()),
207 ..Default::default()
208 }),
209 ..Default::default()
210 },
211 inline_code: TextStyleRefinement {
212 font_family: Some(theme_settings.buffer_font.family.clone()),
213 font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
214 font_features: Some(theme_settings.buffer_font.features.clone()),
215 font_size: Some(buffer_font_size.into()),
216 background_color: Some(colors.editor_foreground.opacity(0.1)),
217 ..Default::default()
218 },
219 link: TextStyleRefinement {
220 background_color: Some(colors.editor_foreground.opacity(0.025)),
221 underline: Some(UnderlineStyle {
222 color: Some(colors.text_accent.opacity(0.5)),
223 thickness: px(1.),
224 ..Default::default()
225 }),
226 ..Default::default()
227 },
228 ..Default::default()
229 };
230
231 cx.new(|cx| {
232 Markdown::new(
233 text,
234 markdown_style,
235 Some(self.language_registry.clone()),
236 None,
237 cx,
238 )
239 })
240 }
241
242 fn handle_thread_event(
243 &mut self,
244 _: &Entity<Thread>,
245 event: &ThreadEvent,
246 window: &mut Window,
247 cx: &mut Context<Self>,
248 ) {
249 match event {
250 ThreadEvent::ShowError(error) => {
251 self.last_error = Some(error.clone());
252 }
253 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
254 self.save_thread(cx);
255 }
256 ThreadEvent::StreamedAssistantText(message_id, text) => {
257 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
258 markdown.update(cx, |markdown, cx| {
259 markdown.append(text, cx);
260 });
261 }
262 }
263 ThreadEvent::MessageAdded(message_id) => {
264 if let Some(message_text) = self
265 .thread
266 .read(cx)
267 .message(*message_id)
268 .map(|message| message.text.clone())
269 {
270 self.push_message(message_id, message_text, window, cx);
271 }
272
273 self.save_thread(cx);
274 cx.notify();
275 }
276 ThreadEvent::MessageEdited(message_id) => {
277 if let Some(message_text) = self
278 .thread
279 .read(cx)
280 .message(*message_id)
281 .map(|message| message.text.clone())
282 {
283 self.edited_message(message_id, message_text, window, cx);
284 }
285
286 self.save_thread(cx);
287 cx.notify();
288 }
289 ThreadEvent::MessageDeleted(message_id) => {
290 self.deleted_message(message_id);
291 self.save_thread(cx);
292 cx.notify();
293 }
294 ThreadEvent::UsePendingTools => {
295 self.thread.update(cx, |thread, cx| {
296 thread.use_pending_tools(cx);
297 });
298 }
299 ThreadEvent::ToolFinished { .. } => {
300 let all_tools_finished = self
301 .thread
302 .read(cx)
303 .pending_tool_uses()
304 .into_iter()
305 .all(|tool_use| tool_use.status.is_error());
306 if all_tools_finished {
307 let model_registry = LanguageModelRegistry::read_global(cx);
308 if let Some(model) = model_registry.active_model() {
309 self.thread.update(cx, |thread, cx| {
310 thread.send_tool_results_to_model(model, cx);
311 });
312 }
313 }
314 }
315 }
316 }
317
318 /// Spawns a task to save the active thread.
319 ///
320 /// Only one task to save the thread will be in flight at a time.
321 fn save_thread(&mut self, cx: &mut Context<Self>) {
322 let thread = self.thread.clone();
323 self.save_thread_task = Some(cx.spawn(|this, mut cx| async move {
324 let task = this
325 .update(&mut cx, |this, cx| {
326 this.thread_store
327 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
328 })
329 .ok();
330
331 if let Some(task) = task {
332 task.await.log_err();
333 }
334 }));
335 }
336
337 fn start_editing_message(
338 &mut self,
339 message_id: MessageId,
340 message_text: String,
341 window: &mut Window,
342 cx: &mut Context<Self>,
343 ) {
344 let buffer = cx.new(|cx| {
345 MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
346 });
347 let editor = cx.new(|cx| {
348 let mut editor = Editor::new(
349 editor::EditorMode::AutoHeight { max_lines: 8 },
350 buffer,
351 None,
352 false,
353 window,
354 cx,
355 );
356 editor.focus_handle(cx).focus(window);
357 editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
358 editor
359 });
360 self.editing_message = Some((
361 message_id,
362 EditMessageState {
363 editor: editor.clone(),
364 },
365 ));
366 cx.notify();
367 }
368
369 fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
370 self.editing_message.take();
371 cx.notify();
372 }
373
374 fn confirm_editing_message(
375 &mut self,
376 _: &menu::Confirm,
377 _: &mut Window,
378 cx: &mut Context<Self>,
379 ) {
380 let Some((message_id, state)) = self.editing_message.take() else {
381 return;
382 };
383 let edited_text = state.editor.read(cx).text(cx);
384 self.thread.update(cx, |thread, cx| {
385 thread.edit_message(message_id, Role::User, edited_text, cx);
386 for message_id in self.messages_after(message_id) {
387 thread.delete_message(*message_id, cx);
388 }
389 });
390
391 let provider = LanguageModelRegistry::read_global(cx).active_provider();
392 if provider
393 .as_ref()
394 .map_or(false, |provider| provider.must_accept_terms(cx))
395 {
396 cx.notify();
397 return;
398 }
399 let model_registry = LanguageModelRegistry::read_global(cx);
400 let Some(model) = model_registry.active_model() else {
401 return;
402 };
403
404 self.thread.update(cx, |thread, cx| {
405 thread.send_to_model(model, RequestKind::Chat, false, cx)
406 });
407 cx.notify();
408 }
409
410 fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
411 self.messages
412 .iter()
413 .rev()
414 .find(|message_id| {
415 self.thread
416 .read(cx)
417 .message(**message_id)
418 .map_or(false, |message| message.role == Role::User)
419 })
420 .cloned()
421 }
422
423 fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
424 self.messages
425 .iter()
426 .position(|id| *id == message_id)
427 .map(|index| &self.messages[index + 1..])
428 .unwrap_or(&[])
429 }
430
431 fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
432 self.cancel_editing_message(&menu::Cancel, window, cx);
433 }
434
435 fn handle_regenerate_click(
436 &mut self,
437 _: &ClickEvent,
438 window: &mut Window,
439 cx: &mut Context<Self>,
440 ) {
441 self.confirm_editing_message(&menu::Confirm, window, cx);
442 }
443
444 fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
445 let message_id = self.messages[ix];
446 let Some(message) = self.thread.read(cx).message(message_id) else {
447 return Empty.into_any();
448 };
449
450 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
451 return Empty.into_any();
452 };
453
454 let context = self.thread.read(cx).context_for_message(message_id);
455 let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
456 let colors = cx.theme().colors();
457
458 // Don't render user messages that are just there for returning tool results.
459 if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) {
460 return Empty.into_any();
461 }
462
463 let allow_editing_message =
464 message.role == Role::User && self.last_user_message(cx) == Some(message_id);
465
466 let edit_message_editor = self
467 .editing_message
468 .as_ref()
469 .filter(|(id, _)| *id == message_id)
470 .map(|(_, state)| state.editor.clone());
471
472 let message_content = v_flex()
473 .child(
474 if let Some(edit_message_editor) = edit_message_editor.clone() {
475 div()
476 .key_context("EditMessageEditor")
477 .on_action(cx.listener(Self::cancel_editing_message))
478 .on_action(cx.listener(Self::confirm_editing_message))
479 .p_2p5()
480 .child(edit_message_editor)
481 } else {
482 div().p_2p5().text_ui(cx).child(markdown.clone())
483 },
484 )
485 .when_some(context, |parent, context| {
486 if !context.is_empty() {
487 parent.child(
488 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
489 context
490 .into_iter()
491 .map(|context| ContextPill::added(context, false, false, None)),
492 ),
493 )
494 } else {
495 parent
496 }
497 });
498
499 let styled_message = match message.role {
500 Role::User => v_flex()
501 .id(("message-container", ix))
502 .pt_2p5()
503 .px_2p5()
504 .child(
505 v_flex()
506 .bg(colors.editor_background)
507 .rounded_lg()
508 .border_1()
509 .border_color(colors.border)
510 .shadow_sm()
511 .child(
512 h_flex()
513 .py_1()
514 .pl_2()
515 .pr_1()
516 .bg(colors.editor_foreground.opacity(0.05))
517 .border_b_1()
518 .border_color(colors.border)
519 .justify_between()
520 .rounded_t(px(6.))
521 .child(
522 h_flex()
523 .gap_1p5()
524 .child(
525 Icon::new(IconName::PersonCircle)
526 .size(IconSize::XSmall)
527 .color(Color::Muted),
528 )
529 .child(
530 Label::new("You")
531 .size(LabelSize::Small)
532 .color(Color::Muted),
533 ),
534 )
535 .when_some(
536 edit_message_editor.clone(),
537 |this, edit_message_editor| {
538 let focus_handle = edit_message_editor.focus_handle(cx);
539 this.child(
540 h_flex()
541 .gap_1()
542 .child(
543 Button::new("cancel-edit-message", "Cancel")
544 .label_size(LabelSize::Small)
545 .key_binding(
546 KeyBinding::for_action_in(
547 &menu::Cancel,
548 &focus_handle,
549 window,
550 cx,
551 )
552 .map(|kb| kb.size(rems_from_px(12.))),
553 )
554 .on_click(
555 cx.listener(Self::handle_cancel_click),
556 ),
557 )
558 .child(
559 Button::new(
560 "confirm-edit-message",
561 "Regenerate",
562 )
563 .label_size(LabelSize::Small)
564 .key_binding(
565 KeyBinding::for_action_in(
566 &menu::Confirm,
567 &focus_handle,
568 window,
569 cx,
570 )
571 .map(|kb| kb.size(rems_from_px(12.))),
572 )
573 .on_click(
574 cx.listener(Self::handle_regenerate_click),
575 ),
576 ),
577 )
578 },
579 )
580 .when(
581 edit_message_editor.is_none() && allow_editing_message,
582 |this| {
583 this.child(
584 Button::new("edit-message", "Edit")
585 .label_size(LabelSize::Small)
586 .on_click(cx.listener({
587 let message_text = message.text.clone();
588 move |this, _, window, cx| {
589 this.start_editing_message(
590 message_id,
591 message_text.clone(),
592 window,
593 cx,
594 );
595 }
596 })),
597 )
598 },
599 ),
600 )
601 .child(message_content),
602 ),
603 Role::Assistant => div()
604 .id(("message-container", ix))
605 .child(message_content)
606 .map(|parent| {
607 if tool_uses.is_empty() {
608 return parent;
609 }
610
611 parent.child(
612 v_flex().children(
613 tool_uses
614 .into_iter()
615 .map(|tool_use| self.render_tool_use(tool_use, cx)),
616 ),
617 )
618 }),
619 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
620 v_flex()
621 .bg(colors.editor_background)
622 .rounded_sm()
623 .child(message_content),
624 ),
625 };
626
627 styled_message.into_any()
628 }
629
630 fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
631 let is_open = self
632 .expanded_tool_uses
633 .get(&tool_use.id)
634 .copied()
635 .unwrap_or_default();
636
637 div().px_2p5().child(
638 v_flex()
639 .gap_1()
640 .rounded_lg()
641 .border_1()
642 .border_color(cx.theme().colors().border)
643 .child(
644 h_flex()
645 .justify_between()
646 .py_0p5()
647 .pl_1()
648 .pr_2()
649 .bg(cx.theme().colors().editor_foreground.opacity(0.02))
650 .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
651 .when(!is_open, |element| element.rounded_md())
652 .border_color(cx.theme().colors().border)
653 .child(
654 h_flex()
655 .gap_1()
656 .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
657 cx.listener({
658 let tool_use_id = tool_use.id.clone();
659 move |this, _event, _window, _cx| {
660 let is_open = this
661 .expanded_tool_uses
662 .entry(tool_use_id.clone())
663 .or_insert(false);
664
665 *is_open = !*is_open;
666 }
667 }),
668 ))
669 .child(Label::new(tool_use.name)),
670 )
671 .child(
672 Label::new(match tool_use.status {
673 ToolUseStatus::Pending => "Pending",
674 ToolUseStatus::Running => "Running",
675 ToolUseStatus::Finished(_) => "Finished",
676 ToolUseStatus::Error(_) => "Error",
677 })
678 .size(LabelSize::XSmall)
679 .buffer_font(cx),
680 ),
681 )
682 .map(|parent| {
683 if !is_open {
684 return parent;
685 }
686
687 parent.child(
688 v_flex()
689 .child(
690 v_flex()
691 .gap_0p5()
692 .py_1()
693 .px_2p5()
694 .border_b_1()
695 .border_color(cx.theme().colors().border)
696 .child(Label::new("Input:"))
697 .child(Label::new(
698 serde_json::to_string_pretty(&tool_use.input)
699 .unwrap_or_default(),
700 )),
701 )
702 .map(|parent| match tool_use.status {
703 ToolUseStatus::Finished(output) => parent.child(
704 v_flex()
705 .gap_0p5()
706 .py_1()
707 .px_2p5()
708 .child(Label::new("Result:"))
709 .child(Label::new(output)),
710 ),
711 ToolUseStatus::Error(err) => parent.child(
712 v_flex()
713 .gap_0p5()
714 .py_1()
715 .px_2p5()
716 .child(Label::new("Error:"))
717 .child(Label::new(err)),
718 ),
719 ToolUseStatus::Pending | ToolUseStatus::Running => parent,
720 }),
721 )
722 }),
723 )
724 }
725}
726
727impl Render for ActiveThread {
728 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
729 v_flex()
730 .size_full()
731 .child(list(self.list_state.clone()).flex_grow())
732 }
733}