1use std::sync::Arc;
2
3use assistant_tool::ToolWorkingSet;
4use collections::HashMap;
5use gpui::{
6 list, AbsoluteLength, AnyElement, AppContext, DefiniteLength, EdgesRefinement, Empty, Length,
7 ListAlignment, ListOffset, ListState, Model, StyleRefinement, Subscription,
8 TextStyleRefinement, UnderlineStyle, View, WeakView,
9};
10use language::LanguageRegistry;
11use language_model::Role;
12use markdown::{Markdown, MarkdownStyle};
13use settings::Settings as _;
14use theme::ThemeSettings;
15use ui::prelude::*;
16use workspace::Workspace;
17
18use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
19use crate::thread_store::ThreadStore;
20use crate::ui::ContextPill;
21
/// View over a single thread: tracks the ids of the messages shown in the
/// list, the Markdown view rendered for each of them, and the last error
/// reported by the thread.
pub struct ActiveThread {
    /// Weak handle to the workspace; handed to tools when they are run.
    workspace: WeakView<Workspace>,
    /// Passed to every message's `Markdown` view when it is created.
    language_registry: Arc<LanguageRegistry>,
    /// Used to look up a tool by name for each pending tool use.
    tools: Arc<ToolWorkingSet>,
    /// Store the thread is saved through after messages are added, completions
    /// are streamed, or the summary changes.
    thread_store: Model<ThreadStore>,
    /// The thread being displayed.
    thread: Model<Thread>,
    /// Ids of the pushed messages in list order; list item `ix` is `messages[ix]`.
    messages: Vec<MessageId>,
    /// State of the list element; spliced and scrolled whenever a message is pushed.
    list_state: ListState,
    /// Markdown view for each pushed message, keyed by message id.
    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
    /// Most recent error shown by the thread, until it is cleared.
    last_error: Option<ThreadError>,
    /// Handles for the observation of and subscription to `thread`.
    /// Presumably held so they stay active while the view lives; confirm against gpui.
    _subscriptions: Vec<Subscription>,
}
34
35impl ActiveThread {
36 pub fn new(
37 thread: Model<Thread>,
38 thread_store: Model<ThreadStore>,
39 workspace: WeakView<Workspace>,
40 language_registry: Arc<LanguageRegistry>,
41 tools: Arc<ToolWorkingSet>,
42 cx: &mut ViewContext<Self>,
43 ) -> Self {
44 let subscriptions = vec![
45 cx.observe(&thread, |_, _, cx| cx.notify()),
46 cx.subscribe(&thread, Self::handle_thread_event),
47 ];
48
49 let mut this = Self {
50 workspace,
51 language_registry,
52 tools,
53 thread_store,
54 thread: thread.clone(),
55 messages: Vec::new(),
56 rendered_messages_by_id: HashMap::default(),
57 list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
58 let this = cx.view().downgrade();
59 move |ix, cx: &mut WindowContext| {
60 this.update(cx, |this, cx| this.render_message(ix, cx))
61 .unwrap()
62 }
63 }),
64 last_error: None,
65 _subscriptions: subscriptions,
66 };
67
68 for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
69 this.push_message(&message.id, message.text.clone(), cx);
70 }
71
72 this
73 }
74
    /// Returns the thread model this view displays.
    pub fn thread(&self) -> &Model<Thread> {
        &self.thread
    }
78
    /// Returns `true` when no message has been pushed to this view yet.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
82
83 pub fn summary(&self, cx: &AppContext) -> Option<SharedString> {
84 self.thread.read(cx).summary()
85 }
86
87 pub fn summary_or_default(&self, cx: &AppContext) -> SharedString {
88 self.thread.read(cx).summary_or_default()
89 }
90
91 pub fn cancel_last_completion(&mut self, cx: &mut AppContext) -> bool {
92 self.last_error.take();
93 self.thread
94 .update(cx, |thread, _cx| thread.cancel_last_completion())
95 }
96
    /// Returns a copy of the most recently stored thread error, if any.
    pub fn last_error(&self) -> Option<ThreadError> {
        self.last_error.clone()
    }
100
101 pub fn clear_last_error(&mut self) {
102 self.last_error.take();
103 }
104
105 fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
106 let old_len = self.messages.len();
107 self.messages.push(*id);
108 self.list_state.splice(old_len..old_len, 1);
109
110 let theme_settings = ThemeSettings::get_global(cx);
111 let colors = cx.theme().colors();
112 let ui_font_size = TextSize::Default.rems(cx);
113 let buffer_font_size = TextSize::Small.rems(cx);
114 let mut text_style = cx.text_style();
115
116 text_style.refine(&TextStyleRefinement {
117 font_family: Some(theme_settings.ui_font.family.clone()),
118 font_size: Some(ui_font_size.into()),
119 color: Some(cx.theme().colors().text),
120 ..Default::default()
121 });
122
123 let markdown_style = MarkdownStyle {
124 base_text_style: text_style,
125 syntax: cx.theme().syntax().clone(),
126 selection_background_color: cx.theme().players().local().selection,
127 code_block: StyleRefinement {
128 margin: EdgesRefinement {
129 top: Some(Length::Definite(rems(0.).into())),
130 left: Some(Length::Definite(rems(0.).into())),
131 right: Some(Length::Definite(rems(0.).into())),
132 bottom: Some(Length::Definite(rems(0.5).into())),
133 },
134 padding: EdgesRefinement {
135 top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
136 left: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
137 right: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
138 bottom: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))),
139 },
140 background: Some(colors.editor_background.into()),
141 border_color: Some(colors.border_variant),
142 border_widths: EdgesRefinement {
143 top: Some(AbsoluteLength::Pixels(Pixels(1.))),
144 left: Some(AbsoluteLength::Pixels(Pixels(1.))),
145 right: Some(AbsoluteLength::Pixels(Pixels(1.))),
146 bottom: Some(AbsoluteLength::Pixels(Pixels(1.))),
147 },
148 text: Some(TextStyleRefinement {
149 font_family: Some(theme_settings.buffer_font.family.clone()),
150 font_size: Some(buffer_font_size.into()),
151 ..Default::default()
152 }),
153 ..Default::default()
154 },
155 inline_code: TextStyleRefinement {
156 font_family: Some(theme_settings.buffer_font.family.clone()),
157 font_size: Some(buffer_font_size.into()),
158 background_color: Some(colors.editor_foreground.opacity(0.1)),
159 ..Default::default()
160 },
161 link: TextStyleRefinement {
162 background_color: Some(colors.editor_foreground.opacity(0.025)),
163 underline: Some(UnderlineStyle {
164 color: Some(colors.text_accent.opacity(0.5)),
165 thickness: px(1.),
166 ..Default::default()
167 }),
168 ..Default::default()
169 },
170 ..Default::default()
171 };
172
173 let markdown = cx.new_view(|cx| {
174 Markdown::new(
175 text,
176 markdown_style,
177 Some(self.language_registry.clone()),
178 None,
179 cx,
180 )
181 });
182 self.rendered_messages_by_id.insert(*id, markdown);
183 self.list_state.scroll_to(ListOffset {
184 item_ix: old_len,
185 offset_in_item: Pixels(0.0),
186 });
187 }
188
189 fn handle_thread_event(
190 &mut self,
191 _: Model<Thread>,
192 event: &ThreadEvent,
193 cx: &mut ViewContext<Self>,
194 ) {
195 match event {
196 ThreadEvent::ShowError(error) => {
197 self.last_error = Some(error.clone());
198 }
199 ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
200 self.thread_store
201 .update(cx, |thread_store, cx| {
202 thread_store.save_thread(&self.thread, cx)
203 })
204 .detach_and_log_err(cx);
205 }
206 ThreadEvent::StreamedAssistantText(message_id, text) => {
207 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
208 markdown.update(cx, |markdown, cx| {
209 markdown.append(text, cx);
210 });
211 }
212 }
213 ThreadEvent::MessageAdded(message_id) => {
214 if let Some(message_text) = self
215 .thread
216 .read(cx)
217 .message(*message_id)
218 .map(|message| message.text.clone())
219 {
220 self.push_message(message_id, message_text, cx);
221 }
222
223 self.thread_store
224 .update(cx, |thread_store, cx| {
225 thread_store.save_thread(&self.thread, cx)
226 })
227 .detach_and_log_err(cx);
228
229 cx.notify();
230 }
231 ThreadEvent::UsePendingTools => {
232 let pending_tool_uses = self
233 .thread
234 .read(cx)
235 .pending_tool_uses()
236 .into_iter()
237 .filter(|tool_use| tool_use.status.is_idle())
238 .cloned()
239 .collect::<Vec<_>>();
240
241 for tool_use in pending_tool_uses {
242 if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
243 let task = tool.run(tool_use.input, self.workspace.clone(), cx);
244
245 self.thread.update(cx, |thread, cx| {
246 thread.insert_tool_output(
247 tool_use.assistant_message_id,
248 tool_use.id.clone(),
249 task,
250 cx,
251 );
252 });
253 }
254 }
255 }
256 ThreadEvent::ToolFinished { .. } => {}
257 }
258 }
259
260 fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
261 let message_id = self.messages[ix];
262 let Some(message) = self.thread.read(cx).message(message_id) else {
263 return Empty.into_any();
264 };
265
266 let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
267 return Empty.into_any();
268 };
269
270 let context = self.thread.read(cx).context_for_message(message_id);
271 let colors = cx.theme().colors();
272
273 let message_content = v_flex()
274 .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
275 .when_some(context, |parent, context| {
276 if !context.is_empty() {
277 parent.child(
278 h_flex().flex_wrap().gap_1().px_1p5().pb_1p5().children(
279 context
280 .into_iter()
281 .map(|context| ContextPill::added(context, false, false, None)),
282 ),
283 )
284 } else {
285 parent
286 }
287 });
288
289 let styled_message = match message.role {
290 Role::User => v_flex()
291 .id(("message-container", ix))
292 .py_1()
293 .px_2p5()
294 .child(
295 v_flex()
296 .bg(colors.editor_background)
297 .rounded_lg()
298 .border_1()
299 .border_color(colors.border)
300 .shadow_sm()
301 .child(
302 h_flex()
303 .py_1()
304 .px_2()
305 .bg(colors.editor_foreground.opacity(0.05))
306 .border_b_1()
307 .border_color(colors.border)
308 .justify_between()
309 .rounded_t(px(6.))
310 .child(
311 h_flex()
312 .gap_1p5()
313 .child(
314 Icon::new(IconName::PersonCircle)
315 .size(IconSize::XSmall)
316 .color(Color::Muted),
317 )
318 .child(
319 Label::new("You")
320 .size(LabelSize::Small)
321 .color(Color::Muted),
322 ),
323 ),
324 )
325 .child(message_content),
326 ),
327 Role::Assistant => div().id(("message-container", ix)).child(message_content),
328 Role::System => div().id(("message-container", ix)).py_1().px_2().child(
329 v_flex()
330 .bg(colors.editor_background)
331 .rounded_md()
332 .child(message_content),
333 ),
334 };
335
336 styled_message.into_any()
337 }
338}
339
340impl Render for ActiveThread {
341 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
342 v_flex()
343 .size_full()
344 .pt_1p5()
345 .child(list(self.list_state.clone()).flex_grow())
346 }
347}