1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
7 ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
8 TextStyleRefinement, UnderlineStyle, View, WeakView,
9};
10use language::LanguageRegistry;
11use language_model::Role;
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::prelude::*;
16use workspace::Workspace;
17
18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
19use crate::ui::ContextPill;
20
21pub struct ActiveThread {
22 workspace: WeakView<Workspace>,
23 language_registry: Arc<LanguageRegistry>,
24 tools: Arc<ToolWorkingSet>,
25 thread: Model<Thread>,
26 messages: Vec<MessageId>,
27 list_state: ListState,
28 rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
29 last_error: Option<ThreadError>,
30 _subscriptions: Vec<Subscription>,
31}
32
33impl ActiveThread {
34 pub fn new(
35 thread: Model<Thread>,
36 workspace: WeakView<Workspace>,
37 language_registry: Arc<LanguageRegistry>,
38 tools: Arc<ToolWorkingSet>,
39 cx: &mut ViewContext<Self>,
40 ) -> Self {
41 let subscriptions = vec![
42 cx.observe(&thread, |_, _, cx| cx.notify()),
43 cx.subscribe(&thread, Self::handle_thread_event),
44 ];
45
46 let mut this = Self {
47 workspace,
48 language_registry,
49 tools,
50 thread: thread.clone(),
51 messages: Vec::new(),
52 rendered_messages_by_id: HashMap::default(),
53 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
54 let this = cx.view().downgrade();
55 move |ix, cx: &mut WindowContext| {
56 this.update(cx, |this, cx| this.render_message(ix, cx))
57 .unwrap()
58 }
59 }),
60 last_error: None,
61 _subscriptions: subscriptions,
62 };
63
64 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
65 this.push_message(&message.id, message.text.clone(), cx);
66 }
67
68 this
69 }
70
71 pub fn thread(&self) -> &Model<Thread> {
72 &self.thread
73 }
74
75 pub fn is_empty(&self) -> bool {
76 self.messages.is_empty()
77 }
78
79 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
80 self.thread.read(cx).summary()
81 }
82
83 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
84 self.thread.read(cx).summary_or_default()
85 }
86
87 pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
88 self.last_error.take();
89 self.thread
90 .update(cx, |thread, _cx| thread.cancel_last_completion())
91 }
92
93 pub fn last_error(&self) -> Option<ThreadError> {
94 self.last_error.clone()
95 }
96
97 pub fn clear_last_error(&mut self) {
98 self.last_error.take();
99 }
100
101 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
102 let old_len = self.messages.len();
103 self.messages.push(*id);
104 self.list_state.splice(old_len..old_len, 1);
105
106 let theme_settings = ThemeSettings::get_global(cx);
107 let colors = cx.theme().colors();
108 let ui_font_size = TextSize::Default.rems(cx);
109 let buffer_font_size = TextSize::Small.rems(cx);
110 let mut text_style = cx.text_style();
111
112 text_style.refine(&TextStyleRefinement {
113 font_family: Some(theme_settings.ui_font.family.clone()),
114 font_size: Some(ui_font_size.into()),
115 color: Some(cx.theme().colors().text),
116 ..Default::default()
117 });
118
119 let markdown_style = MarkdownStyle {
120 base_text_style: text_style,
121 syntax: cx.theme().syntax().clone(),
122 selection_background_color: cx.theme().players().local().selection,
123 code_block: StyleRefinement {
124 margin: EdgesRefinement {
125 top: Some(Length::Definite(rems(0.).into())),
126 left: Some(Length::Definite(rems(0.).into())),
127 right: Some(Length::Definite(rems(0.).into())),
128 bottom: Some(Length::Definite(rems(0.5).into())),
129 },
130 padding: EdgesRefinement {
131 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
132 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
133 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
134 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
135 },
136 background: Some(colors.editor_background.into()),
137 border_color: Some(colors.border_variant),
138 border_widths: EdgesRefinement {
139 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
140 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
141 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
142 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
143 },
144 text: Some(TextStyleRefinement {
145 font_family: Some(theme_settings.buffer_font.family.clone()),
146 font_size: Some(buffer_font_size.into()),
147 ..Default::default()
148 }),
149 ..Default::default()
150 },
151 inline_code: TextStyleRefinement {
152 font_family: Some(theme_settings.buffer_font.family.clone()),
153 font_size: Some(buffer_font_size.into()),
154 background_color: Some(colors.editor_foreground.opacity(0.1)),
155 ..Default::default()
156 },
157 link: TextStyleRefinement {
158 background_color: Some(colors.editor_foreground.opacity(0.025)),
159 underline: Some(UnderlineStyle {
160 color: Some(colors.text_accent.opacity(0.5)),
161 thickness: px(1.),
162 ..Default::default()
163 }),
164 ..Default::default()
165 },
166 ..Default::default()
167 };
168
169 let markdown = cx.new_view(|cx| {
170 Markdown::new(
171 text,
172 markdown_style,
173 Some(self.language_registry.clone()),
174 None,
175 cx,
176 )
177 });
178 self.rendered_messages_by_id.insert(*id, markdown);
179 self.list_state.scroll_to(ListOffset {
180 item_ix: old_len,
181 offset_in_item: Pixels(0.0),
182 });
183 }
184
185 fn handle_thread_event(
186 &mut self,
187 _: Model<Thread>,
188 event: &ThreadEvent,
189 cx: &mut ViewContext<Self>,
190 ) {
191 match event {
192 ThreadEvent::ShowError(error) => {
193 self.last_error = Some(error.clone());
194 }
195 ThreadEvent::StreamedCompletion => {}
196 ThreadEvent::SummaryChanged => {}
197 ThreadEvent::StreamedAssistantText(message_id, text) => {
198 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
199 markdown.update(cx, |markdown, cx| {
200 markdown.append(text, cx);
201 });
202 }
203 }
204 ThreadEvent::MessageAdded(message_id) => {
205 if let Some(message_text) = self
206 .thread
207 .read(cx)
208 .message(*message_id)
209 .map(|message| message.text.clone())
210 {
211 self.push_message(message_id, message_text, cx);
212 }
213
214 cx.notify();
215 }
216 ThreadEvent::UsePendingTools => {
217 let pending_tool_uses = self
218 .thread
219 .read(cx)
220 .pending_tool_uses()
221 .into_iter()
222 .filter(|tool_use| tool_use.status.is_idle())
223 .cloned()
224 .collect::<Vec<_>>();
225
226 for tool_use in pending_tool_uses {
227 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
228 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
229
230 self.thread.update(cx, |thread, cx| {
231 thread.insert_tool_output(
232 tool_use.assistant_message_id,
233 tool_use.id.clone(),
234 task,
235 cx,
236 );
237 });
238 }
239 }
240 }
241 ThreadEvent::ToolFinished { .. } => {}
242 }
243 }
244
245 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
246 let message_id = self.messages[ix];
247 let Some(message) = self.thread.read(cx).message(message_id) else {
248 return Empty.into_any();
249 };
250
251 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
252 return Empty.into_any();
253 };
254
255 let context = self.thread.read(cx).context_for_message(message_id);
256 let colors = cx.theme().colors();
257
258 let message_content = v_flex()
259 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
260 .when_some(context, |parent, context| {
261 if !context.is_empty() {
262 parent.child(
263 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
264 context
265 .into_iter()
266 .map(|context| ContextPill::added(context, false, false, None)),
267 ),
268 )
269 } else {
270 parent
271 }
272 });
273
274 let styled_message = match message.role {
275 Role::User => v_flex()
276 .id(("message-container", ix))
277 .py_1()
278 .px_2p5()
279 .child(
280 v_flex()
281 .bg(colors.editor_background)
282 .rounded_lg()
283 .border_1()
284 .border_color(colors.border)
285 .shadow_sm()
286 .child(
287 h_flex()
288 .py_1()
289 .px_2()
290 .bg(colors.editor_foreground.opacity(0.05))
291 .border_b_1()
292 .border_color(colors.border)
293 .justify_between()
294 .rounded_t(px(6.))
295 .child(
296 h_flex()
297 .gap_1p5()
298 .child(
299 Icon::new(IconName::PersonCircle)
300 .size(IconSize::XSmall)
301 .color(Color::Muted),
302 )
303 .child(
304 Label::new("You")
305 .size(LabelSize::Small)
306 .color(Color::Muted),
307 ),
308 ),
309 )
310 .child(message_content),
311 ),
312 Role::Assistant => div().id(("message-container", ix)).child(message_content),
313 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
314 v_flex()
315 .bg(colors.editor_background)
316 .rounded_md()
317 .child(message_content),
318 ),
319 };
320
321 styled_message.into_any()
322 }
323}
324
325impl Render for ActiveThread {
326 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
327 v_flex()
328 .size_full()
329 .pt_1p5()
330 .child(list(self.list_state.clone()).flex_grow())
331 }
332}