1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
7 ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
8 TextStyleRefinement, UnderlineStyle, View, WeakView,
9};
10use language::LanguageRegistry;
11use language_model::Role;
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::prelude::*;
16use workspace::Workspace;
17
18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
19use crate::ui::ContextPill;
20
/// A view that renders the messages of a single [`Thread`] as a list and
/// reacts to the thread's events (streamed text, new messages, errors and
/// pending tool uses).
pub struct ActiveThread {
    /// Workspace handle passed to each tool when a pending tool use is run.
    workspace: WeakView<Workspace>,
    /// Language registry handed to every message's `Markdown` view.
    language_registry: Arc<LanguageRegistry>,
    /// Tools looked up by name when the thread asks for pending tool uses.
    tools: Arc<ToolWorkingSet>,
    /// The thread being displayed; observed and subscribed to in `new`.
    pub(crate) thread: Model<Thread>,
    /// Message ids in display order; index `i` corresponds to list item `i`.
    messages: Vec<MessageId>,
    /// List state whose item count is kept in step with `messages`; items are
    /// rendered through `render_message`.
    list_state: ListState,
    /// One markdown view per message; streamed assistant text is appended here.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error reported by the thread via `ThreadEvent::ShowError`.
    last_error: Option<ThreadError>,
    /// Held (never read) to keep the observation/subscription created in `new`;
    /// presumably dropping them would cancel them.
    _subscriptions: Vec<Subscription>,
}
32
33impl ActiveThread {
34 pub fn new(
35 thread: Model<Thread>,
36 workspace: WeakView<Workspace>,
37 language_registry: Arc<LanguageRegistry>,
38 tools: Arc<ToolWorkingSet>,
39 cx: &mut ViewContext<Self>,
40 ) -> Self {
41 let subscriptions = vec![
42 cx.observe(&thread, |_, _, cx| cx.notify()),
43 cx.subscribe(&thread, Self::handle_thread_event),
44 ];
45
46 let mut this = Self {
47 workspace,
48 language_registry,
49 tools,
50 thread: thread.clone(),
51 messages: Vec::new(),
52 rendered_messages_by_id: HashMap::default(),
53 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
54 let this = cx.view().downgrade();
55 move |ix, cx: &mut WindowContext| {
56 this.update(cx, |this, cx| this.render_message(ix, cx))
57 .unwrap()
58 }
59 }),
60 last_error: None,
61 _subscriptions: subscriptions,
62 };
63
64 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
65 this.push_message(&message.id, message.text.clone(), cx);
66 }
67
68 this
69 }
70
71 pub fn is_empty(&self) -> bool {
72 self.messages.is_empty()
73 }
74
75 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
76 self.thread.read(cx).summary()
77 }
78
79 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
80 self.thread.read(cx).summary_or_default()
81 }
82
83 pub fn last_error(&self) -> Option<ThreadError> {
84 self.last_error.clone()
85 }
86
87 pub fn clear_last_error(&mut self) {
88 self.last_error.take();
89 }
90
91 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
92 let old_len = self.messages.len();
93 self.messages.push(*id);
94 self.list_state.splice(old_len..old_len, 1);
95
96 let theme_settings = ThemeSettings::get_global(cx);
97 let colors = cx.theme().colors();
98 let ui_font_size = TextSize::Default.rems(cx);
99 let buffer_font_size = TextSize::Small.rems(cx);
100 let mut text_style = cx.text_style();
101
102 text_style.refine(&TextStyleRefinement {
103 font_family: Some(theme_settings.ui_font.family.clone()),
104 font_size: Some(ui_font_size.into()),
105 color: Some(cx.theme().colors().text),
106 ..Default::default()
107 });
108
109 let markdown_style = MarkdownStyle {
110 base_text_style: text_style,
111 syntax: cx.theme().syntax().clone(),
112 selection_background_color: cx.theme().players().local().selection,
113 code_block: StyleRefinement {
114 margin: EdgesRefinement {
115 top: Some(Length::Definite(rems(1.0).into())),
116 left: Some(Length::Definite(rems(0.).into())),
117 right: Some(Length::Definite(rems(0.).into())),
118 bottom: Some(Length::Definite(rems(1.).into())),
119 },
120 padding: EdgesRefinement {
121 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
122 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
123 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
124 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
125 },
126 background: Some(colors.editor_foreground.opacity(0.01).into()),
127 border_color: Some(colors.border_variant.opacity(0.3)),
128 border_widths: EdgesRefinement {
129 top: Some(AbsoluteLength::Pixels(Pixels(1.0))),
130 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
131 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
132 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
133 },
134 text: Some(TextStyleRefinement {
135 font_family: Some(theme_settings.buffer_font.family.clone()),
136 font_size: Some(buffer_font_size.into()),
137 ..Default::default()
138 }),
139 ..Default::default()
140 },
141 inline_code: TextStyleRefinement {
142 font_family: Some(theme_settings.buffer_font.family.clone()),
143 font_size: Some(buffer_font_size.into()),
144 background_color: Some(colors.editor_foreground.opacity(0.1)),
145 ..Default::default()
146 },
147 link: TextStyleRefinement {
148 background_color: Some(colors.editor_foreground.opacity(0.025)),
149 underline: Some(UnderlineStyle {
150 color: Some(colors.text_accent.opacity(0.5)),
151 thickness: px(1.),
152 ..Default::default()
153 }),
154 ..Default::default()
155 },
156 ..Default::default()
157 };
158
159 let markdown = cx.new_view(|cx| {
160 Markdown::new(
161 text,
162 markdown_style,
163 Some(self.language_registry.clone()),
164 None,
165 cx,
166 )
167 });
168 self.rendered_messages_by_id.insert(*id, markdown);
169 self.list_state.scroll_to(ListOffset {
170 item_ix: old_len,
171 offset_in_item: Pixels(0.0),
172 });
173 }
174
175 fn handle_thread_event(
176 &mut self,
177 _: Model<Thread>,
178 event: &ThreadEvent,
179 cx: &mut ViewContext<Self>,
180 ) {
181 match event {
182 ThreadEvent::ShowError(error) => {
183 self.last_error = Some(error.clone());
184 }
185 ThreadEvent::StreamedCompletion => {}
186 ThreadEvent::SummaryChanged => {}
187 ThreadEvent::StreamedAssistantText(message_id, text) => {
188 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
189 markdown.update(cx, |markdown, cx| {
190 markdown.append(text, cx);
191 });
192 }
193 }
194 ThreadEvent::MessageAdded(message_id) => {
195 if let Some(message_text) = self
196 .thread
197 .read(cx)
198 .message(*message_id)
199 .map(|message| message.text.clone())
200 {
201 self.push_message(message_id, message_text, cx);
202 }
203
204 cx.notify();
205 }
206 ThreadEvent::UsePendingTools => {
207 let pending_tool_uses = self
208 .thread
209 .read(cx)
210 .pending_tool_uses()
211 .into_iter()
212 .filter(|tool_use| tool_use.status.is_idle())
213 .cloned()
214 .collect::<Vec<_>>();
215
216 for tool_use in pending_tool_uses {
217 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
218 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
219
220 self.thread.update(cx, |thread, cx| {
221 thread.insert_tool_output(
222 tool_use.assistant_message_id,
223 tool_use.id.clone(),
224 task,
225 cx,
226 );
227 });
228 }
229 }
230 }
231 ThreadEvent::ToolFinished { .. } => {}
232 }
233 }
234
235 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
236 let message_id = self.messages[ix];
237 let Some(message) = self.thread.read(cx).message(message_id) else {
238 return Empty.into_any();
239 };
240
241 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
242 return Empty.into_any();
243 };
244
245 let context = self.thread.read(cx).context_for_message(message_id);
246 let colors = cx.theme().colors();
247
248 let (role_icon, role_name, role_color) = match message.role {
249 Role::User => (IconName::Person, "You", Color::Muted),
250 Role::Assistant => (IconName::ZedAssistant, "Assistant", Color::Accent),
251 Role::System => (IconName::Settings, "System", Color::Default),
252 };
253
254 div()
255 .id(("message-container", ix))
256 .py_1()
257 .px_2()
258 .child(
259 v_flex()
260 .border_1()
261 .border_color(colors.border_variant)
262 .bg(colors.editor_background)
263 .rounded_md()
264 .child(
265 h_flex()
266 .py_1p5()
267 .px_2p5()
268 .border_b_1()
269 .border_color(colors.border_variant)
270 .justify_between()
271 .child(
272 h_flex()
273 .gap_1p5()
274 .child(
275 Icon::new(role_icon)
276 .size(IconSize::XSmall)
277 .color(role_color),
278 )
279 .child(
280 Label::new(role_name)
281 .size(LabelSize::XSmall)
282 .color(role_color),
283 ),
284 ),
285 )
286 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
287 .when_some(context, |parent, context| {
288 if !context.is_empty() {
289 parent.child(
290 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
291 context.into_iter().map(|context| {
292 ContextPill::new_added(context, false, None)
293 }),
294 ),
295 )
296 } else {
297 parent
298 }
299 }),
300 )
301 .into_any()
302 }
303}
304
305impl Render for ActiveThread {
306 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
307 list(self.list_state.clone()).flex_1().py_1()
308 }
309}