1use std::sync::Arc;
2use std::time::Duration;
3
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use gpui::{
7 linear_color_stop, linear_gradient, list, percentage, AbsoluteLength, Animation, AnimationExt,
8 AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, FocusHandle, Length,
9 ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
10 TextStyleRefinement, Transformation, UnderlineStyle, View, WeakView,
11};
12use language::LanguageRegistry;
13use language_model::Role;
14use markdown::{Markdown, MarkdownStyle};
15use settings::Settings as _;
16use theme::ThemeSettings;
17use ui::{prelude::*, Divider, KeyBinding};
18use workspace::Workspace;
19
20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
21use crate::ui::ContextPill;
22
/// View state for displaying a single [`Thread`]: the ordered list of its
/// messages, a rendered markdown view per message, and the last error the
/// thread reported.
pub struct ActiveThread {
    /// Workspace handle handed to tools when a pending tool use is run.
    workspace: WeakView<Workspace>,
    /// Registry passed to each message's `Markdown` view.
    language_registry: Arc<LanguageRegistry>,
    /// Tool set used to resolve pending tool uses by name.
    tools: Arc<ToolWorkingSet>,
    /// The thread whose messages this view displays.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order; index `i` corresponds to list item `i`
    /// rendered through `list_state`.
    messages: Vec<MessageId>,
    /// Backing state for the virtualized message list.
    list_state: ListState,
    /// Markdown view for each displayed message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported by the thread, until cleared or a completion
    /// is cancelled.
    last_error: Option<ThreadError>,
    /// Focus handle used for the cancel button's key binding and dispatch.
    focus_handle: FocusHandle,
    /// Subscriptions to `thread` created in `new`; never read directly.
    _subscriptions: Vec<Subscription>,
}
35
36impl ActiveThread {
37 pub fn new(
38 thread: Model<Thread>,
39 workspace: WeakView<Workspace>,
40 language_registry: Arc<LanguageRegistry>,
41 tools: Arc<ToolWorkingSet>,
42 focus_handle: FocusHandle,
43 cx: &mut ViewContext<Self>,
44 ) -> Self {
45 let subscriptions = vec![
46 cx.observe(&thread, |_, _, cx| cx.notify()),
47 cx.subscribe(&thread, Self::handle_thread_event),
48 ];
49
50 let mut this = Self {
51 workspace,
52 language_registry,
53 tools,
54 thread: thread.clone(),
55 messages: Vec::new(),
56 rendered_messages_by_id: HashMap::default(),
57 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
58 let this = cx.view().downgrade();
59 move |ix, cx: &mut WindowContext| {
60 this.update(cx, |this, cx| this.render_message(ix, cx))
61 .unwrap()
62 }
63 }),
64 last_error: None,
65 focus_handle,
66 _subscriptions: subscriptions,
67 };
68
69 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
70 this.push_message(&message.id, message.text.clone(), cx);
71 }
72
73 this
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.messages.is_empty()
78 }
79
80 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
81 self.thread.read(cx).summary()
82 }
83
84 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
85 self.thread.read(cx).summary_or_default()
86 }
87
88 pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
89 self.last_error.take();
90 self.thread
91 .update(cx, |thread, _cx| thread.cancel_last_completion())
92 }
93
94 pub fn last_error(&self) -> Option<ThreadError> {
95 self.last_error.clone()
96 }
97
98 pub fn clear_last_error(&mut self) {
99 self.last_error.take();
100 }
101
102 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
103 let old_len = self.messages.len();
104 self.messages.push(*id);
105 self.list_state.splice(old_len..old_len, 1);
106
107 let theme_settings = ThemeSettings::get_global(cx);
108 let colors = cx.theme().colors();
109 let ui_font_size = TextSize::Default.rems(cx);
110 let buffer_font_size = TextSize::Small.rems(cx);
111 let mut text_style = cx.text_style();
112
113 text_style.refine(&TextStyleRefinement {
114 font_family: Some(theme_settings.ui_font.family.clone()),
115 font_size: Some(ui_font_size.into()),
116 color: Some(cx.theme().colors().text),
117 ..Default::default()
118 });
119
120 let markdown_style = MarkdownStyle {
121 base_text_style: text_style,
122 syntax: cx.theme().syntax().clone(),
123 selection_background_color: cx.theme().players().local().selection,
124 code_block: StyleRefinement {
125 margin: EdgesRefinement {
126 top: Some(Length::Definite(rems(0.).into())),
127 left: Some(Length::Definite(rems(0.).into())),
128 right: Some(Length::Definite(rems(0.).into())),
129 bottom: Some(Length::Definite(rems(0.5).into())),
130 },
131 padding: EdgesRefinement {
132 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
134 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
135 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
136 },
137 background: Some(colors.editor_background.into()),
138 border_color: Some(colors.border_variant),
139 border_widths: EdgesRefinement {
140 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
141 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
142 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
143 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
144 },
145 text: Some(TextStyleRefinement {
146 font_family: Some(theme_settings.buffer_font.family.clone()),
147 font_size: Some(buffer_font_size.into()),
148 ..Default::default()
149 }),
150 ..Default::default()
151 },
152 inline_code: TextStyleRefinement {
153 font_family: Some(theme_settings.buffer_font.family.clone()),
154 font_size: Some(buffer_font_size.into()),
155 background_color: Some(colors.editor_foreground.opacity(0.1)),
156 ..Default::default()
157 },
158 link: TextStyleRefinement {
159 background_color: Some(colors.editor_foreground.opacity(0.025)),
160 underline: Some(UnderlineStyle {
161 color: Some(colors.text_accent.opacity(0.5)),
162 thickness: px(1.),
163 ..Default::default()
164 }),
165 ..Default::default()
166 },
167 ..Default::default()
168 };
169
170 let markdown = cx.new_view(|cx| {
171 Markdown::new(
172 text,
173 markdown_style,
174 Some(self.language_registry.clone()),
175 None,
176 cx,
177 )
178 });
179 self.rendered_messages_by_id.insert(*id, markdown);
180 self.list_state.scroll_to(ListOffset {
181 item_ix: old_len,
182 offset_in_item: Pixels(0.0),
183 });
184 }
185
186 fn handle_thread_event(
187 &mut self,
188 _: Model<Thread>,
189 event: &ThreadEvent,
190 cx: &mut ViewContext<Self>,
191 ) {
192 match event {
193 ThreadEvent::ShowError(error) => {
194 self.last_error = Some(error.clone());
195 }
196 ThreadEvent::StreamedCompletion => {}
197 ThreadEvent::SummaryChanged => {}
198 ThreadEvent::StreamedAssistantText(message_id, text) => {
199 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
200 markdown.update(cx, |markdown, cx| {
201 markdown.append(text, cx);
202 });
203 }
204 }
205 ThreadEvent::MessageAdded(message_id) => {
206 if let Some(message_text) = self
207 .thread
208 .read(cx)
209 .message(*message_id)
210 .map(|message| message.text.clone())
211 {
212 self.push_message(message_id, message_text, cx);
213 }
214
215 cx.notify();
216 }
217 ThreadEvent::UsePendingTools => {
218 let pending_tool_uses = self
219 .thread
220 .read(cx)
221 .pending_tool_uses()
222 .into_iter()
223 .filter(|tool_use| tool_use.status.is_idle())
224 .cloned()
225 .collect::<Vec<_>>();
226
227 for tool_use in pending_tool_uses {
228 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
229 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
230
231 self.thread.update(cx, |thread, cx| {
232 thread.insert_tool_output(
233 tool_use.assistant_message_id,
234 tool_use.id.clone(),
235 task,
236 cx,
237 );
238 });
239 }
240 }
241 }
242 ThreadEvent::ToolFinished { .. } => {}
243 }
244 }
245
246 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
247 let message_id = self.messages[ix];
248 let Some(message) = self.thread.read(cx).message(message_id) else {
249 return Empty.into_any();
250 };
251
252 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
253 return Empty.into_any();
254 };
255
256 let context = self.thread.read(cx).context_for_message(message_id);
257 let colors = cx.theme().colors();
258
259 let message_content = v_flex()
260 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
261 .when_some(context, |parent, context| {
262 if !context.is_empty() {
263 parent.child(
264 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
265 context
266 .into_iter()
267 .map(|context| ContextPill::new_added(context, false, false, None)),
268 ),
269 )
270 } else {
271 parent
272 }
273 });
274
275 let styled_message = match message.role {
276 Role::User => v_flex()
277 .id(("message-container", ix))
278 .py_1()
279 .px_2p5()
280 .child(
281 v_flex()
282 .bg(colors.editor_background)
283 .ml_16()
284 .rounded_t_lg()
285 .rounded_bl_lg()
286 .rounded_br_none()
287 .border_1()
288 .border_color(colors.border)
289 .child(
290 h_flex()
291 .py_1()
292 .px_2()
293 .bg(colors.editor_foreground.opacity(0.05))
294 .border_b_1()
295 .border_color(colors.border)
296 .justify_between()
297 .rounded_t(px(6.))
298 .child(
299 h_flex()
300 .gap_1p5()
301 .child(
302 Icon::new(IconName::PersonCircle)
303 .size(IconSize::XSmall)
304 .color(Color::Muted),
305 )
306 .child(
307 Label::new("You")
308 .size(LabelSize::Small)
309 .color(Color::Muted),
310 ),
311 ),
312 )
313 .child(message_content),
314 ),
315 Role::Assistant => div().id(("message-container", ix)).child(message_content),
316 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
317 v_flex()
318 .bg(colors.editor_background)
319 .rounded_md()
320 .child(message_content),
321 ),
322 };
323
324 styled_message.into_any()
325 }
326}
327
328impl Render for ActiveThread {
329 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
330 let is_streaming_completion = self.thread.read(cx).is_streaming();
331 let panel_bg = cx.theme().colors().panel_background;
332 let focus_handle = self.focus_handle.clone();
333
334 v_flex()
335 .size_full()
336 .pt_1p5()
337 .child(list(self.list_state.clone()).flex_grow())
338 .when(is_streaming_completion, |parent| {
339 parent.child(
340 h_flex()
341 .w_full()
342 .pb_2p5()
343 .absolute()
344 .bottom_0()
345 .flex_shrink()
346 .justify_center()
347 .bg(linear_gradient(
348 180.,
349 linear_color_stop(panel_bg.opacity(0.0), 0.),
350 linear_color_stop(panel_bg, 1.),
351 ))
352 .child(
353 h_flex()
354 .flex_none()
355 .p_1p5()
356 .bg(cx.theme().colors().editor_background)
357 .border_1()
358 .border_color(cx.theme().colors().border)
359 .rounded_md()
360 .shadow_lg()
361 .gap_1()
362 .child(
363 Icon::new(IconName::ArrowCircle)
364 .size(IconSize::Small)
365 .color(Color::Muted)
366 .with_animation(
367 "arrow-circle",
368 Animation::new(Duration::from_secs(2)).repeat(),
369 |icon, delta| {
370 icon.transform(Transformation::rotate(percentage(
371 delta,
372 )))
373 },
374 ),
375 )
376 .child(
377 Label::new("Generating…")
378 .size(LabelSize::Small)
379 .color(Color::Muted),
380 )
381 .child(Divider::vertical())
382 .child(
383 Button::new("cancel-generation", "Cancel")
384 .label_size(LabelSize::Small)
385 .key_binding(KeyBinding::for_action_in(
386 &editor::actions::Cancel,
387 &self.focus_handle,
388 cx,
389 ))
390 .on_click(move |_event, cx| {
391 focus_handle
392 .dispatch_action(&editor::actions::Cancel, cx);
393 }),
394 ),
395 ),
396 )
397 })
398 }
399}