1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
7 ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
8 UnderlineStyle, WeakEntity,
9};
10use language::LanguageRegistry;
11use language_model::Role;
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::prelude::*;
16use workspace::Workspace;
17
18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
19use crate::thread_store::ThreadStore;
20use crate::ui::ContextPill;
21
/// View state for a single conversation thread: the ordered list of message ids
/// shown in a `list`, the rendered markdown for each message, and the most
/// recent error reported by the thread.
///
/// NOTE: field order also determines drop order; keep it stable.
pub struct ActiveThread {
    /// Weak handle to the workspace, handed to each tool when it is run.
    workspace: WeakEntity<Workspace>,
    /// Passed to `Markdown::new` for every rendered message.
    language_registry: Arc<LanguageRegistry>,
    /// Tools that can be looked up by name to execute pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// Store whose `save_thread` is called when the thread changes.
    thread_store: Entity<ThreadStore>,
    /// The thread being displayed.
    thread: Entity<Thread>,
    /// Message ids in display order; entry `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// State for the message list; items are rendered on demand via `render_message`.
    list_state: ListState,
    /// Rendered markdown for each message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
    /// Most recent error reported through `ThreadEvent::ShowError`, until cleared.
    last_error: Option<ThreadError>,
    /// Keeps the thread observation and event subscription alive.
    _subscriptions: Vec<Subscription>,
}
34
35impl ActiveThread {
36 pub fn new(
37 thread: Entity<Thread>,
38 thread_store: Entity<ThreadStore>,
39 workspace: WeakEntity<Workspace>,
40 language_registry: Arc<LanguageRegistry>,
41 tools: Arc<ToolWorkingSet>,
42 window: &mut Window,
43 cx: &mut Context<Self>,
44 ) -> Self {
45 let subscriptions = vec![
46 cx.observe(&thread, |_, _, cx| cx.notify()),
47 cx.subscribe_in(&thread, window, Self::handle_thread_event),
48 ];
49
50 let mut this = Self {
51 workspace,
52 language_registry,
53 tools,
54 thread_store,
55 thread: thread.clone(),
56 messages: Vec::new(),
57 rendered_messages_by_id: HashMap::default(),
58 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
59 let this = cx.entity().downgrade();
60 move |ix, _: &mut Window, cx: &mut App| {
61 this.update(cx, |this, cx| this.render_message(ix, cx))
62 .unwrap()
63 }
64 }),
65 last_error: None,
66 _subscriptions: subscriptions,
67 };
68
69 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
70 this.push_message(&message.id, message.text.clone(), window, cx);
71 }
72
73 this
74 }
75
76 pub fn thread(&self) -> &Entity<Thread> {
77 &self.thread
78 }
79
80 pub fn is_empty(&self) -> bool {
81 self.messages.is_empty()
82 }
83
84 pub fn summary(&self, cx: &App) -> Option<SharedString> {
85 self.thread.read(cx).summary()
86 }
87
88 pub fn summary_or_default(&self, cx: &App) -> SharedString {
89 self.thread.read(cx).summary_or_default()
90 }
91
92 pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
93 self.last_error.take();
94 self.thread
95 .update(cx, |thread, _cx| thread.cancel_last_completion())
96 }
97
98 pub fn last_error(&self) -> Option<ThreadError> {
99 self.last_error.clone()
100 }
101
102 pub fn clear_last_error(&mut self) {
103 self.last_error.take();
104 }
105
106 fn push_message(
107 &mut self,
108 id: &MessageId,
109 text: String,
110 window: &mut Window,
111 cx: &mut Context<Self>,
112 ) {
113 let old_len = self.messages.len();
114 self.messages.push(*id);
115 self.list_state.splice(old_len..old_len, 1);
116
117 let theme_settings = ThemeSettings::get_global(cx);
118 let colors = cx.theme().colors();
119 let ui_font_size = TextSize::Default.rems(cx);
120 let buffer_font_size = TextSize::Small.rems(cx);
121 let mut text_style = window.text_style();
122
123 text_style.refine(&TextStyleRefinement {
124 font_family: Some(theme_settings.ui_font.family.clone()),
125 font_size: Some(ui_font_size.into()),
126 color: Some(cx.theme().colors().text),
127 ..Default::default()
128 });
129
130 let markdown_style = MarkdownStyle {
131 base_text_style: text_style,
132 syntax: cx.theme().syntax().clone(),
133 selection_background_color: cx.theme().players().local().selection,
134 code_block: StyleRefinement {
135 margin: EdgesRefinement {
136 top: Some(Length::Definite(rems(0.).into())),
137 left: Some(Length::Definite(rems(0.).into())),
138 right: Some(Length::Definite(rems(0.).into())),
139 bottom: Some(Length::Definite(rems(0.5).into())),
140 },
141 padding: EdgesRefinement {
142 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
143 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
144 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
145 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
146 },
147 background: Some(colors.editor_background.into()),
148 border_color: Some(colors.border_variant),
149 border_widths: EdgesRefinement {
150 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
151 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
152 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
153 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
154 },
155 text: Some(TextStyleRefinement {
156 font_family: Some(theme_settings.buffer_font.family.clone()),
157 font_size: Some(buffer_font_size.into()),
158 ..Default::default()
159 }),
160 ..Default::default()
161 },
162 inline_code: TextStyleRefinement {
163 font_family: Some(theme_settings.buffer_font.family.clone()),
164 font_size: Some(buffer_font_size.into()),
165 background_color: Some(colors.editor_foreground.opacity(0.1)),
166 ..Default::default()
167 },
168 link: TextStyleRefinement {
169 background_color: Some(colors.editor_foreground.opacity(0.025)),
170 underline: Some(UnderlineStyle {
171 color: Some(colors.text_accent.opacity(0.5)),
172 thickness: px(1.),
173 ..Default::default()
174 }),
175 ..Default::default()
176 },
177 ..Default::default()
178 };
179
180 let markdown = cx.new(|cx| {
181 Markdown::new(
182 text.into(),
183 markdown_style,
184 Some(self.language_registry.clone()),
185 None,
186 cx,
187 )
188 });
189 self.rendered_messages_by_id.insert(*id, markdown);
190 self.list_state.scroll_to(ListOffset {
191 item_ix: old_len,
192 offset_in_item: Pixels(0.0),
193 });
194 }
195
196 fn handle_thread_event(
197 &mut self,
198 _: &Entity<Thread>,
199 event: &ThreadEvent,
200 window: &mut Window,
201 cx: &mut Context<Self>,
202 ) {
203 match event {
204 ThreadEvent::ShowError(error) => {
205 self.last_error = Some(error.clone());
206 }
207 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
208 self.thread_store
209 .update(cx, |thread_store, cx| {
210 thread_store.save_thread(&self.thread, cx)
211 })
212 .detach_and_log_err(cx);
213 }
214 ThreadEvent::StreamedAssistantText(message_id, text) => {
215 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
216 markdown.update(cx, |markdown, cx| {
217 markdown.append(text, cx);
218 });
219 }
220 }
221 ThreadEvent::MessageAdded(message_id) => {
222 if let Some(message_text) = self
223 .thread
224 .read(cx)
225 .message(*message_id)
226 .map(|message| message.text.clone())
227 {
228 self.push_message(message_id, message_text, window, cx);
229 }
230
231 self.thread_store
232 .update(cx, |thread_store, cx| {
233 thread_store.save_thread(&self.thread, cx)
234 })
235 .detach_and_log_err(cx);
236
237 cx.notify();
238 }
239 ThreadEvent::UsePendingTools => {
240 let pending_tool_uses = self
241 .thread
242 .read(cx)
243 .pending_tool_uses()
244 .into_iter()
245 .filter(|tool_use| tool_use.status.is_idle())
246 .cloned()
247 .collect::<Vec<_>>();
248
249 for tool_use in pending_tool_uses {
250 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
251 let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
252
253 self.thread.update(cx, |thread, cx| {
254 thread.insert_tool_output(
255 tool_use.assistant_message_id,
256 tool_use.id.clone(),
257 task,
258 cx,
259 );
260 });
261 }
262 }
263 }
264 ThreadEvent::ToolFinished { .. } => {}
265 }
266 }
267
268 fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
269 let message_id = self.messages[ix];
270 let Some(message) = self.thread.read(cx).message(message_id) else {
271 return Empty.into_any();
272 };
273
274 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
275 return Empty.into_any();
276 };
277
278 let context = self.thread.read(cx).context_for_message(message_id);
279 let colors = cx.theme().colors();
280
281 let message_content = v_flex()
282 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
283 .when_some(context, |parent, context| {
284 if !context.is_empty() {
285 parent.child(
286 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
287 context
288 .into_iter()
289 .map(|context| ContextPill::added(context, false, false, None)),
290 ),
291 )
292 } else {
293 parent
294 }
295 });
296
297 let styled_message = match message.role {
298 Role::User => v_flex()
299 .id(("message-container", ix))
300 .pt_2p5()
301 .px_2p5()
302 .child(
303 v_flex()
304 .bg(colors.editor_background)
305 .rounded_lg()
306 .border_1()
307 .border_color(colors.border)
308 .shadow_sm()
309 .child(
310 h_flex()
311 .py_1()
312 .px_2()
313 .bg(colors.editor_foreground.opacity(0.05))
314 .border_b_1()
315 .border_color(colors.border)
316 .justify_between()
317 .rounded_t(px(6.))
318 .child(
319 h_flex()
320 .gap_1p5()
321 .child(
322 Icon::new(IconName::PersonCircle)
323 .size(IconSize::XSmall)
324 .color(Color::Muted),
325 )
326 .child(
327 Label::new("You")
328 .size(LabelSize::Small)
329 .color(Color::Muted),
330 ),
331 ),
332 )
333 .child(message_content),
334 ),
335 Role::Assistant => div().id(("message-container", ix)).child(message_content),
336 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
337 v_flex()
338 .bg(colors.editor_background)
339 .rounded_md()
340 .child(message_content),
341 ),
342 };
343
344 styled_message.into_any()
345 }
346}
347
348impl Render for ActiveThread {
349 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
350 v_flex()
351 .size_full()
352 .child(list(self.list_state.clone()).flex_grow())
353 }
354}