1use std::sync::Arc;
2use std::time::Duration;
3
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use gpui::{
7 list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, AppContext,
8 DefiniteLength, EdgesRefinement, Empty, Length, ListAlignment, ListOffset, ListState, Model,
9 StyleRefinement, Subscription, TextStyleRefinement, Transformation, UnderlineStyle, View,
10 WeakView,
11};
12use language::LanguageRegistry;
13use language_model::Role;
14use markdown::{Markdown, MarkdownStyle};
15use settings::Settings as _;
16use theme::ThemeSettings;
17use ui::prelude::*;
18use workspace::Workspace;
19
20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
21use crate::ui::ContextPill;
22
/// View state for a single assistant [`Thread`].
///
/// Tracks the ordered list of message ids being displayed, one rendered [`Markdown`]
/// view per message, and the most recent error reported by the thread.
pub struct ActiveThread {
    /// Handed to tools when they are run for pending tool uses.
    workspace: WeakView<Workspace>,
    /// Passed to each message's `Markdown` view.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks to run pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// The thread model this view renders.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order; list item `ix` renders `messages[ix]`.
    messages: Vec<MessageId>,
    /// Backing state for the virtualized message list.
    list_state: ListState,
    /// Markdown view for each displayed message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Last error received via `ThreadEvent::ShowError`, until cleared.
    last_error: Option<ThreadError>,
    /// Held only to keep the thread observer/subscription created in `new` alive
    /// (presumably dropping a `Subscription` unregisters it).
    _subscriptions: Vec<Subscription>,
}
34
35impl ActiveThread {
36 pub fn new(
37 thread: Model<Thread>,
38 workspace: WeakView<Workspace>,
39 language_registry: Arc<LanguageRegistry>,
40 tools: Arc<ToolWorkingSet>,
41 cx: &mut ViewContext<Self>,
42 ) -> Self {
43 let subscriptions = vec![
44 cx.observe(&thread, |_, _, cx| cx.notify()),
45 cx.subscribe(&thread, Self::handle_thread_event),
46 ];
47
48 let mut this = Self {
49 workspace,
50 language_registry,
51 tools,
52 thread: thread.clone(),
53 messages: Vec::new(),
54 rendered_messages_by_id: HashMap::default(),
55 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
56 let this = cx.view().downgrade();
57 move |ix, cx: &mut WindowContext| {
58 this.update(cx, |this, cx| this.render_message(ix, cx))
59 .unwrap()
60 }
61 }),
62 last_error: None,
63 _subscriptions: subscriptions,
64 };
65
66 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
67 this.push_message(&message.id, message.text.clone(), cx);
68 }
69
70 this
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.messages.is_empty()
75 }
76
77 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
78 self.thread.read(cx).summary()
79 }
80
81 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
82 self.thread.read(cx).summary_or_default()
83 }
84
85 pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
86 self.last_error.take();
87 self.thread
88 .update(cx, |thread, _cx| thread.cancel_last_completion())
89 }
90
91 pub fn last_error(&self) -> Option<ThreadError> {
92 self.last_error.clone()
93 }
94
95 pub fn clear_last_error(&mut self) {
96 self.last_error.take();
97 }
98
99 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
100 let old_len = self.messages.len();
101 self.messages.push(*id);
102 self.list_state.splice(old_len..old_len, 1);
103
104 let theme_settings = ThemeSettings::get_global(cx);
105 let colors = cx.theme().colors();
106 let ui_font_size = TextSize::Default.rems(cx);
107 let buffer_font_size = TextSize::Small.rems(cx);
108 let mut text_style = cx.text_style();
109
110 text_style.refine(&TextStyleRefinement {
111 font_family: Some(theme_settings.ui_font.family.clone()),
112 font_size: Some(ui_font_size.into()),
113 color: Some(cx.theme().colors().text),
114 ..Default::default()
115 });
116
117 let markdown_style = MarkdownStyle {
118 base_text_style: text_style,
119 syntax: cx.theme().syntax().clone(),
120 selection_background_color: cx.theme().players().local().selection,
121 code_block: StyleRefinement {
122 margin: EdgesRefinement {
123 top: Some(Length::Definite(rems(1.0).into())),
124 left: Some(Length::Definite(rems(0.).into())),
125 right: Some(Length::Definite(rems(0.).into())),
126 bottom: Some(Length::Definite(rems(1.).into())),
127 },
128 padding: EdgesRefinement {
129 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
130 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
131 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
132 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133 },
134 background: Some(colors.editor_foreground.opacity(0.01).into()),
135 border_color: Some(colors.border_variant.opacity(0.3)),
136 border_widths: EdgesRefinement {
137 top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
138 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
139 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
140 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
141 },
142 text: Some(TextStyleRefinement {
143 font_family: Some(theme_settings.buffer_font.family.clone()),
144 font_size: Some(buffer_font_size.into()),
145 ..Default::default()
146 }),
147 ..Default::default()
148 },
149 inline_code: TextStyleRefinement {
150 font_family: Some(theme_settings.buffer_font.family.clone()),
151 font_size: Some(buffer_font_size.into()),
152 background_color: Some(colors.editor_foreground.opacity(0.1)),
153 ..Default::default()
154 },
155 link: TextStyleRefinement {
156 background_color: Some(colors.editor_foreground.opacity(0.025)),
157 underline: Some(UnderlineStyle {
158 color: Some(colors.text_accent.opacity(0.5)),
159 thickness: px(1.),
160 ..Default::default()
161 }),
162 ..Default::default()
163 },
164 ..Default::default()
165 };
166
167 let markdown = cx.new_view(|cx| {
168 Markdown::new(
169 text,
170 markdown_style,
171 Some(self.language_registry.clone()),
172 None,
173 cx,
174 )
175 });
176 self.rendered_messages_by_id.insert(*id, markdown);
177 self.list_state.scroll_to(ListOffset {
178 item_ix: old_len,
179 offset_in_item: Pixels(0.0),
180 });
181 }
182
183 fn handle_thread_event(
184 &mut self,
185 _: Model<Thread>,
186 event: &ThreadEvent,
187 cx: &mut ViewContext<Self>,
188 ) {
189 match event {
190 ThreadEvent::ShowError(error) => {
191 self.last_error = Some(error.clone());
192 }
193 ThreadEvent::StreamedCompletion => {}
194 ThreadEvent::SummaryChanged => {}
195 ThreadEvent::StreamedAssistantText(message_id, text) => {
196 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
197 markdown.update(cx, |markdown, cx| {
198 markdown.append(text, cx);
199 });
200 }
201 }
202 ThreadEvent::MessageAdded(message_id) => {
203 if let Some(message_text) = self
204 .thread
205 .read(cx)
206 .message(*message_id)
207 .map(|message| message.text.clone())
208 {
209 self.push_message(message_id, message_text, cx);
210 }
211
212 cx.notify();
213 }
214 ThreadEvent::UsePendingTools => {
215 let pending_tool_uses = self
216 .thread
217 .read(cx)
218 .pending_tool_uses()
219 .into_iter()
220 .filter(|tool_use| tool_use.status.is_idle())
221 .cloned()
222 .collect::<Vec<_>>();
223
224 for tool_use in pending_tool_uses {
225 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
226 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
227
228 self.thread.update(cx, |thread, cx| {
229 thread.insert_tool_output(
230 tool_use.assistant_message_id,
231 tool_use.id.clone(),
232 task,
233 cx,
234 );
235 });
236 }
237 }
238 }
239 ThreadEvent::ToolFinished { .. } => {}
240 }
241 }
242
243 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
244 let message_id = self.messages[ix];
245 let is_last_message = ix == self.messages.len() - 1;
246 let Some(message) = self.thread.read(cx).message(message_id) else {
247 return Empty.into_any();
248 };
249
250 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
251 return Empty.into_any();
252 };
253
254 let is_streaming_completion = self.thread.read(cx).is_streaming();
255 let context = self.thread.read(cx).context_for_message(message_id);
256 let colors = cx.theme().colors();
257
258 let (role_icon, role_name, role_color) = match message.role {
259 Role::User => (IconName::Person, "You", Color::Muted),
260 Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
261 Role::System => (IconName::Settings, "System", Color::Default),
262 };
263
264 div()
265 .id(("message-container", ix))
266 .py_1()
267 .px_2()
268 .child(
269 v_flex()
270 .border_1()
271 .border_color(colors.border_variant)
272 .bg(colors.editor_background)
273 .rounded_md()
274 .child(
275 h_flex()
276 .py_1p5()
277 .px_2p5()
278 .border_b_1()
279 .border_color(colors.border_variant)
280 .justify_between()
281 .child(
282 h_flex()
283 .gap_1p5()
284 .child(
285 Icon::new(role_icon)
286 .size(IconSize::XSmall)
287 .color(role_color),
288 )
289 .child(
290 Label::new(role_name)
291 .size(LabelSize::XSmall)
292 .color(role_color),
293 ),
294 ),
295 )
296 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
297 .when(
298 message.role == Role::Assistant
299 && is_last_message
300 && is_streaming_completion,
301 |parent| {
302 parent.child(
303 h_flex()
304 .gap_1()
305 .p_2p5()
306 .child(
307 Icon::new(IconName::ArrowCircle)
308 .size(IconSize::Small)
309 .color(Color::Muted)
310 .with_animation(
311 "arrow-circle",
312 Animation::new(Duration::from_secs(2)).repeat(),
313 |icon, delta| {
314 icon.transform(Transformation::rotate(
315 percentage(delta),
316 ))
317 },
318 ),
319 )
320 .child(
321 Label::new("Generating…")
322 .size(LabelSize::Small)
323 .color(Color::Muted),
324 ),
325 )
326 },
327 )
328 .when_some(context, |parent, context| {
329 if !context.is_empty() {
330 parent.child(h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
331 context.into_iter().map(|context| {
332 ContextPill::new_added(context, false, false, None)
333 }),
334 ))
335 } else {
336 parent
337 }
338 }),
339 )
340 .into_any()
341 }
342}
343
344impl Render for ActiveThread {
345 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
346 let is_streaming_completion = self.thread.read(cx).is_streaming();
347
348 v_flex()
349 .size_full()
350 .child(list(self.list_state.clone()).flex_grow())
351 .child(
352 h_flex()
353 .absolute()
354 .bottom_1()
355 .flex_shrink()
356 .justify_center()
357 .w_full()
358 .when(is_streaming_completion, |parent| {
359 parent.child(
360 h_flex()
361 .gap_2()
362 .p_1p5()
363 .rounded_md()
364 .bg(cx.theme().colors().elevated_surface_background)
365 .child(Label::new("Generating…").size(LabelSize::Small))
366 .child(Label::new("esc to cancel").size(LabelSize::Small)),
367 )
368 }),
369 )
370 }
371}