1use std::sync::Arc;
2use std::time::Duration;
3
4use assistant_tool::ToolWorkingSet;
5use collections::HashMap;
6use gpui::{
7 list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, AppContext,
8 DefiniteLength, EdgesRefinement, Empty, Length, ListAlignment, ListOffset, ListState, Model,
9 StyleRefinement, Subscription, TextStyleRefinement, Transformation, UnderlineStyle, View,
10 WeakView,
11};
12use language::LanguageRegistry;
13use language_model::Role;
14use markdown::{Markdown, MarkdownStyle};
15use settings::Settings as _;
16use theme::ThemeSettings;
17use ui::prelude::*;
18use workspace::Workspace;
19
20use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
21use crate::ui::ContextPill;
22
/// A view that renders a [`Thread`] as a vertically scrolling list of messages.
///
/// Each message body is rendered through its own [`Markdown`] view. Assistant text
/// streamed into the thread is appended to the matching view incrementally.
pub struct ActiveThread {
    /// Workspace handed to tools when running the thread's pending tool uses.
    workspace: WeakView<Workspace>,
    /// Passed to each message's [`Markdown`] view when it is created.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks for pending tool uses to run.
    tools: Arc<ToolWorkingSet>,
    /// The thread being displayed; this view observes and subscribes to it.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order. Index `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// State of the `list` element; items are rendered by `render_message`.
    list_state: ListState,
    /// The markdown view created for each message id.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported through `ThreadEvent::ShowError`, if not yet cleared.
    last_error: Option<ThreadError>,
    /// Observation/subscription on `thread`, stored so they live as long as this view.
    _subscriptions: Vec<Subscription>,
}
34
35impl ActiveThread {
36 pub fn new(
37 thread: Model<Thread>,
38 workspace: WeakView<Workspace>,
39 language_registry: Arc<LanguageRegistry>,
40 tools: Arc<ToolWorkingSet>,
41 cx: &mut ViewContext<Self>,
42 ) -> Self {
43 let subscriptions = vec![
44 cx.observe(&thread, |_, _, cx| cx.notify()),
45 cx.subscribe(&thread, Self::handle_thread_event),
46 ];
47
48 let mut this = Self {
49 workspace,
50 language_registry,
51 tools,
52 thread: thread.clone(),
53 messages: Vec::new(),
54 rendered_messages_by_id: HashMap::default(),
55 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
56 let this = cx.view().downgrade();
57 move |ix, cx: &mut WindowContext| {
58 this.update(cx, |this, cx| this.render_message(ix, cx))
59 .unwrap()
60 }
61 }),
62 last_error: None,
63 _subscriptions: subscriptions,
64 };
65
66 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
67 this.push_message(&message.id, message.text.clone(), cx);
68 }
69
70 this
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.messages.is_empty()
75 }
76
77 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
78 self.thread.read(cx).summary()
79 }
80
81 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
82 self.thread.read(cx).summary_or_default()
83 }
84
85 pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
86 self.last_error.take();
87 self.thread
88 .update(cx, |thread, _cx| thread.cancel_last_completion())
89 }
90
91 pub fn last_error(&self) -> Option<ThreadError> {
92 self.last_error.clone()
93 }
94
95 pub fn clear_last_error(&mut self) {
96 self.last_error.take();
97 }
98
99 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
100 let old_len = self.messages.len();
101 self.messages.push(*id);
102 self.list_state.splice(old_len..old_len, 1);
103
104 let theme_settings = ThemeSettings::get_global(cx);
105 let colors = cx.theme().colors();
106 let ui_font_size = TextSize::Default.rems(cx);
107 let buffer_font_size = TextSize::Small.rems(cx);
108 let mut text_style = cx.text_style();
109
110 text_style.refine(&TextStyleRefinement {
111 font_family: Some(theme_settings.ui_font.family.clone()),
112 font_size: Some(ui_font_size.into()),
113 color: Some(cx.theme().colors().text),
114 ..Default::default()
115 });
116
117 let markdown_style = MarkdownStyle {
118 base_text_style: text_style,
119 syntax: cx.theme().syntax().clone(),
120 selection_background_color: cx.theme().players().local().selection,
121 code_block: StyleRefinement {
122 margin: EdgesRefinement {
123 top: Some(Length::Definite(rems(1.0).into())),
124 left: Some(Length::Definite(rems(0.).into())),
125 right: Some(Length::Definite(rems(0.).into())),
126 bottom: Some(Length::Definite(rems(1.).into())),
127 },
128 padding: EdgesRefinement {
129 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
130 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
131 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
132 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133 },
134 background: Some(colors.editor_foreground.opacity(0.01).into()),
135 border_color: Some(colors.border_variant.opacity(0.3)),
136 border_widths: EdgesRefinement {
137 top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
138 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
139 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
140 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
141 },
142 text: Some(TextStyleRefinement {
143 font_family: Some(theme_settings.buffer_font.family.clone()),
144 font_size: Some(buffer_font_size.into()),
145 ..Default::default()
146 }),
147 ..Default::default()
148 },
149 inline_code: TextStyleRefinement {
150 font_family: Some(theme_settings.buffer_font.family.clone()),
151 font_size: Some(buffer_font_size.into()),
152 background_color: Some(colors.editor_foreground.opacity(0.1)),
153 ..Default::default()
154 },
155 link: TextStyleRefinement {
156 background_color: Some(colors.editor_foreground.opacity(0.025)),
157 underline: Some(UnderlineStyle {
158 color: Some(colors.text_accent.opacity(0.5)),
159 thickness: px(1.),
160 ..Default::default()
161 }),
162 ..Default::default()
163 },
164 ..Default::default()
165 };
166
167 let markdown = cx.new_view(|cx| {
168 Markdown::new(
169 text,
170 markdown_style,
171 Some(self.language_registry.clone()),
172 None,
173 cx,
174 )
175 });
176 self.rendered_messages_by_id.insert(*id, markdown);
177 self.list_state.scroll_to(ListOffset {
178 item_ix: old_len,
179 offset_in_item: Pixels(0.0),
180 });
181 }
182
183 fn handle_thread_event(
184 &mut self,
185 _: Model<Thread>,
186 event: &ThreadEvent,
187 cx: &mut ViewContext<Self>,
188 ) {
189 match event {
190 ThreadEvent::ShowError(error) => {
191 self.last_error = Some(error.clone());
192 }
193 ThreadEvent::StreamedCompletion => {}
194 ThreadEvent::SummaryChanged => {}
195 ThreadEvent::StreamedAssistantText(message_id, text) => {
196 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
197 markdown.update(cx, |markdown, cx| {
198 markdown.append(text, cx);
199 });
200 }
201 }
202 ThreadEvent::MessageAdded(message_id) => {
203 if let Some(message_text) = self
204 .thread
205 .read(cx)
206 .message(*message_id)
207 .map(|message| message.text.clone())
208 {
209 self.push_message(message_id, message_text, cx);
210 }
211
212 cx.notify();
213 }
214 ThreadEvent::UsePendingTools => {
215 let pending_tool_uses = self
216 .thread
217 .read(cx)
218 .pending_tool_uses()
219 .into_iter()
220 .filter(|tool_use| tool_use.status.is_idle())
221 .cloned()
222 .collect::<Vec<_>>();
223
224 for tool_use in pending_tool_uses {
225 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
226 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
227
228 self.thread.update(cx, |thread, cx| {
229 thread.insert_tool_output(
230 tool_use.assistant_message_id,
231 tool_use.id.clone(),
232 task,
233 cx,
234 );
235 });
236 }
237 }
238 }
239 ThreadEvent::ToolFinished { .. } => {}
240 }
241 }
242
243 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
244 let message_id = self.messages[ix];
245 let Some(message) = self.thread.read(cx).message(message_id) else {
246 return Empty.into_any();
247 };
248
249 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
250 return Empty.into_any();
251 };
252
253 let is_streaming_completion = self.thread.read(cx).is_streaming();
254 let context = self.thread.read(cx).context_for_message(message_id);
255 let colors = cx.theme().colors();
256
257 let (role_icon, role_name, role_color) = match message.role {
258 Role::User => (IconName::Person, "You", Color::Muted),
259 Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
260 Role::System => (IconName::Settings, "System", Color::Default),
261 };
262
263 div()
264 .id(("message-container", ix))
265 .py_1()
266 .px_2()
267 .child(
268 v_flex()
269 .border_1()
270 .border_color(colors.border_variant)
271 .bg(colors.editor_background)
272 .rounded_md()
273 .child(
274 h_flex()
275 .py_1p5()
276 .px_2p5()
277 .border_b_1()
278 .border_color(colors.border_variant)
279 .justify_between()
280 .child(
281 h_flex()
282 .gap_1p5()
283 .child(
284 Icon::new(role_icon)
285 .size(IconSize::XSmall)
286 .color(role_color),
287 )
288 .child(
289 Label::new(role_name)
290 .size(LabelSize::XSmall)
291 .color(role_color),
292 ),
293 ),
294 )
295 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
296 .when(
297 message.role == Role::Assistant && is_streaming_completion,
298 |parent| {
299 parent.child(
300 h_flex()
301 .gap_1()
302 .p_2p5()
303 .child(
304 Icon::new(IconName::ArrowCircle)
305 .size(IconSize::Small)
306 .color(Color::Muted)
307 .with_animation(
308 "arrow-circle",
309 Animation::new(Duration::from_secs(2)).repeat(),
310 |icon, delta| {
311 icon.transform(Transformation::rotate(
312 percentage(delta),
313 ))
314 },
315 ),
316 )
317 .child(
318 Label::new("Generating…")
319 .size(LabelSize::Small)
320 .color(Color::Muted),
321 ),
322 )
323 },
324 )
325 .when_some(context, |parent, context| {
326 if !context.is_empty() {
327 parent.child(
328 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
329 context.into_iter().map(|context| {
330 ContextPill::new_added(context, false, None)
331 }),
332 ),
333 )
334 } else {
335 parent
336 }
337 }),
338 )
339 .into_any()
340 }
341}
342
343impl Render for ActiveThread {
344 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
345 list(self.list_state.clone()).flex_1().py_1()
346 }
347}