1use std::{
2 cell::RefCell,
3 collections::HashSet,
4 fmt::Display,
5 rc::{Rc, Weak},
6 sync::Arc,
7};
8
9use agent_client_protocol as acp;
10use collections::HashMap;
11use gpui::{
12 App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
13 StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
14};
15use language::LanguageRegistry;
16use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
17use project::Project;
18use settings::Settings;
19use theme::ThemeSettings;
20use ui::prelude::*;
21use util::ResultExt as _;
22use workspace::{Item, Workspace};
23
24actions!(dev, [OpenAcpLogs]);
25
26pub fn init(cx: &mut App) {
27 cx.observe_new(
28 |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
29 workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
30 let acp_tools =
31 Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
32 workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
33 });
34 },
35 )
36 .detach();
37}
38
/// gpui global slot holding the app-wide `AcpConnectionRegistry` entity.
///
/// Installed lazily by `AcpConnectionRegistry::default_global` the first time it is requested.
struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);

impl Global for GlobalAcpConnectionRegistry {}
42
/// Records the most recently registered ACP connection so that log views can find and watch it.
///
/// The connection lives in a `RefCell` so that `set_active_connection` can replace it through
/// `&self`; observers of this entity are notified on every replacement.
#[derive(Default)]
pub struct AcpConnectionRegistry {
    // `None` until a connection has been registered.
    active_connection: RefCell<Option<ActiveConnection>>,
}
47
/// The connection currently published by the registry.
struct ActiveConnection {
    // Display name of the agent server; copied into the view's tab title.
    server_name: SharedString,
    // Held weakly (via `Rc::downgrade`) so the registry does not keep the connection alive.
    connection: Weak<acp::ClientSideConnection>,
}
52
53impl AcpConnectionRegistry {
54 pub fn default_global(cx: &mut App) -> Entity<Self> {
55 if cx.has_global::<GlobalAcpConnectionRegistry>() {
56 cx.global::<GlobalAcpConnectionRegistry>().0.clone()
57 } else {
58 let registry = cx.new(|_cx| AcpConnectionRegistry::default());
59 cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
60 registry
61 }
62 }
63
64 pub fn set_active_connection(
65 &self,
66 server_name: impl Into<SharedString>,
67 connection: &Rc<acp::ClientSideConnection>,
68 cx: &mut Context<Self>,
69 ) {
70 self.active_connection.replace(Some(ActiveConnection {
71 server_name: server_name.into(),
72 connection: Rc::downgrade(connection),
73 }));
74 cx.notify();
75 }
76}
77
/// Workspace item that lists the request/response/notification messages flowing over the
/// registry's active ACP connection.
struct AcpTools {
    // Used to obtain the language registry for rendering params as markdown.
    project: Entity<Project>,
    focus_handle: FocusHandle,
    // Indices into the watched connection's `messages` whose rows are shown expanded.
    expanded: HashSet<usize>,
    // The connection whose messages are being recorded; `None` until one is available.
    watched_connection: Option<WatchedConnection>,
    connection_registry: Entity<AcpConnectionRegistry>,
    // Holds the observation of `connection_registry` created in `new`.
    _subscription: Subscription,
}
86
/// Recording state for one connection; replaced wholesale when a different connection is watched.
struct WatchedConnection {
    server_name: SharedString,
    // Every message received so far, in arrival order.
    messages: Vec<WatchedConnectionMessage>,
    // List state for `messages`; one item is spliced in per pushed message.
    list_state: ListState,
    connection: Weak<acp::ClientSideConnection>,
    // Method names of requests received from the peer, keyed by request id, awaiting our response.
    incoming_request_methods: HashMap<i32, Arc<str>>,
    // Method names of requests we sent, keyed by request id, awaiting the peer's response.
    outgoing_request_methods: HashMap<i32, Arc<str>>,
    // Task forwarding the connection's message stream into the view.
    _task: Task<()>,
}
96
97impl AcpTools {
98 fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
99 let connection_registry = AcpConnectionRegistry::default_global(cx);
100
101 let subscription = cx.observe(&connection_registry, |this, _, cx| {
102 this.update_connection(cx);
103 cx.notify();
104 });
105
106 let mut this = Self {
107 project,
108 focus_handle: cx.focus_handle(),
109 expanded: HashSet::default(),
110 watched_connection: None,
111 connection_registry,
112 _subscription: subscription,
113 };
114 this.update_connection(cx);
115 this
116 }
117
118 fn update_connection(&mut self, cx: &mut Context<Self>) {
119 let active_connection = self.connection_registry.read(cx).active_connection.borrow();
120 let Some(active_connection) = active_connection.as_ref() else {
121 return;
122 };
123
124 if let Some(watched_connection) = self.watched_connection.as_ref() {
125 if Weak::ptr_eq(
126 &watched_connection.connection,
127 &active_connection.connection,
128 ) {
129 return;
130 }
131 }
132
133 if let Some(connection) = active_connection.connection.upgrade() {
134 let mut receiver = connection.subscribe();
135 let task = cx.spawn(async move |this, cx| {
136 while let Ok(message) = receiver.recv().await {
137 this.update(cx, |this, cx| {
138 this.push_stream_message(message, cx);
139 })
140 .ok();
141 }
142 });
143
144 self.watched_connection = Some(WatchedConnection {
145 server_name: active_connection.server_name.clone(),
146 messages: vec![],
147 list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
148 connection: active_connection.connection.clone(),
149 incoming_request_methods: HashMap::default(),
150 outgoing_request_methods: HashMap::default(),
151 _task: task,
152 });
153 }
154 }
155
156 fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
157 let Some(connection) = self.watched_connection.as_mut() else {
158 return;
159 };
160 let language_registry = self.project.read(cx).languages().clone();
161 let index = connection.messages.len();
162
163 let (request_id, method, message_type, params) = match stream_message.message {
164 acp::StreamMessageContent::Request { id, method, params } => {
165 let method_map = match stream_message.direction {
166 acp::StreamMessageDirection::Incoming => {
167 &mut connection.incoming_request_methods
168 }
169 acp::StreamMessageDirection::Outgoing => {
170 &mut connection.outgoing_request_methods
171 }
172 };
173
174 method_map.insert(id, method.clone());
175 (Some(id), method.into(), MessageType::Request, Ok(params))
176 }
177 acp::StreamMessageContent::Response { id, result } => {
178 let method_map = match stream_message.direction {
179 acp::StreamMessageDirection::Incoming => {
180 &mut connection.outgoing_request_methods
181 }
182 acp::StreamMessageDirection::Outgoing => {
183 &mut connection.incoming_request_methods
184 }
185 };
186
187 if let Some(method) = method_map.remove(&id) {
188 (Some(id), method.into(), MessageType::Response, result)
189 } else {
190 (
191 Some(id),
192 "[unrecognized response]".into(),
193 MessageType::Response,
194 result,
195 )
196 }
197 }
198 acp::StreamMessageContent::Notification { method, params } => {
199 (None, method.into(), MessageType::Notification, Ok(params))
200 }
201 };
202
203 let message = WatchedConnectionMessage {
204 name: method,
205 message_type,
206 request_id,
207 direction: stream_message.direction,
208 collapsed_params_md: match params.as_ref() {
209 Ok(params) => params
210 .as_ref()
211 .map(|params| collapsed_params_md(params, &language_registry, cx)),
212 Err(err) => {
213 if let Ok(err) = &serde_json::to_value(err) {
214 Some(collapsed_params_md(&err, &language_registry, cx))
215 } else {
216 None
217 }
218 }
219 },
220
221 expanded_params_md: None,
222 params,
223 };
224
225 connection.messages.push(message);
226 connection.list_state.splice(index..index, 1);
227 cx.notify();
228 }
229
230 fn render_message(
231 &mut self,
232 index: usize,
233 window: &mut Window,
234 cx: &mut Context<Self>,
235 ) -> AnyElement {
236 let Some(connection) = self.watched_connection.as_ref() else {
237 return Empty.into_any();
238 };
239
240 let Some(message) = connection.messages.get(index) else {
241 return Empty.into_any();
242 };
243
244 let base_size = TextSize::Editor.rems(cx);
245
246 let theme_settings = ThemeSettings::get_global(cx);
247 let text_style = window.text_style();
248
249 let colors = cx.theme().colors();
250 let expanded = self.expanded.contains(&index);
251
252 v_flex()
253 .w_full()
254 .px_4()
255 .py_3()
256 .border_color(colors.border)
257 .border_b_1()
258 .gap_2()
259 .items_start()
260 .font_buffer(cx)
261 .text_size(base_size)
262 .id(index)
263 .group("message")
264 .hover(|this| this.bg(colors.element_background.opacity(0.5)))
265 .on_click(cx.listener(move |this, _, _, cx| {
266 if this.expanded.contains(&index) {
267 this.expanded.remove(&index);
268 } else {
269 this.expanded.insert(index);
270 let Some(connection) = &mut this.watched_connection else {
271 return;
272 };
273 let Some(message) = connection.messages.get_mut(index) else {
274 return;
275 };
276 message.expanded(this.project.read(cx).languages().clone(), cx);
277 connection.list_state.scroll_to_reveal_item(index);
278 }
279 cx.notify()
280 }))
281 .child(
282 h_flex()
283 .w_full()
284 .gap_2()
285 .items_center()
286 .flex_shrink_0()
287 .child(match message.direction {
288 acp::StreamMessageDirection::Incoming => {
289 ui::Icon::new(ui::IconName::ArrowDown).color(Color::Error)
290 }
291 acp::StreamMessageDirection::Outgoing => {
292 ui::Icon::new(ui::IconName::ArrowUp).color(Color::Success)
293 }
294 })
295 .child(
296 Label::new(message.name.clone())
297 .buffer_font(cx)
298 .color(Color::Muted),
299 )
300 .child(div().flex_1())
301 .child(
302 div()
303 .child(ui::Chip::new(message.message_type.to_string()))
304 .visible_on_hover("message"),
305 )
306 .children(
307 message
308 .request_id
309 .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
310 ),
311 )
312 // I'm aware using markdown is a hack. Trying to get something working for the demo.
313 // Will clean up soon!
314 .when_some(
315 if expanded {
316 message.expanded_params_md.clone()
317 } else {
318 message.collapsed_params_md.clone()
319 },
320 |this, params| {
321 this.child(
322 div().pl_6().w_full().child(
323 MarkdownElement::new(
324 params,
325 MarkdownStyle {
326 base_text_style: text_style,
327 selection_background_color: colors.element_selection_background,
328 syntax: cx.theme().syntax().clone(),
329 code_block_overflow_x_scroll: true,
330 code_block: StyleRefinement {
331 text: Some(TextStyleRefinement {
332 font_family: Some(
333 theme_settings.buffer_font.family.clone(),
334 ),
335 font_size: Some((base_size * 0.8).into()),
336 ..Default::default()
337 }),
338 ..Default::default()
339 },
340 ..Default::default()
341 },
342 )
343 .code_block_renderer(
344 CodeBlockRenderer::Default {
345 copy_button: false,
346 copy_button_on_hover: expanded,
347 border: false,
348 },
349 ),
350 ),
351 )
352 },
353 )
354 .into_any()
355 }
356}
357
/// One recorded stream message together with its markdown renderings.
struct WatchedConnectionMessage {
    // Method name, or `[unrecognized response]` for a response with no recorded request.
    name: SharedString,
    // Set for requests and responses; `None` for notifications.
    request_id: Option<i32>,
    direction: acp::StreamMessageDirection,
    message_type: MessageType,
    // Request/notification params, or a response's result (which may be an error).
    params: Result<Option<serde_json::Value>, acp::Error>,
    // Compact rendering shown while the row is collapsed.
    collapsed_params_md: Option<Entity<Markdown>>,
    // Pretty-printed rendering, built when the row is expanded.
    expanded_params_md: Option<Entity<Markdown>>,
}
367
368impl WatchedConnectionMessage {
369 fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
370 let params_md = match &self.params {
371 Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
372 Err(err) => {
373 if let Some(err) = &serde_json::to_value(err).log_err() {
374 Some(expanded_params_md(&err, &language_registry, cx))
375 } else {
376 None
377 }
378 }
379 _ => None,
380 };
381 self.expanded_params_md = params_md;
382 }
383}
384
385fn collapsed_params_md(
386 params: &serde_json::Value,
387 language_registry: &Arc<LanguageRegistry>,
388 cx: &mut App,
389) -> Entity<Markdown> {
390 let params_json = serde_json::to_string(params).unwrap_or_default();
391 let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
392
393 for ch in params_json.chars() {
394 match ch {
395 '{' => spaced_out_json.push_str("{ "),
396 '}' => spaced_out_json.push_str(" }"),
397 ':' => spaced_out_json.push_str(": "),
398 ',' => spaced_out_json.push_str(", "),
399 c => spaced_out_json.push(c),
400 }
401 }
402
403 let params_md = format!("```json\n{}\n```", spaced_out_json);
404 cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
405}
406
407fn expanded_params_md(
408 params: &serde_json::Value,
409 language_registry: &Arc<LanguageRegistry>,
410 cx: &mut App,
411) -> Entity<Markdown> {
412 let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
413 let params_md = format!("```json\n{}\n```", params_json);
414 cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
415}
416
/// Kind of a recorded stream message; its `Display` text is shown as a chip on each row.
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Request => "Request",
            Self::Response => "Response",
            Self::Notification => "Notification",
        };
        f.write_str(label)
    }
}
432
433enum AcpToolsEvent {}
434
435impl EventEmitter<AcpToolsEvent> for AcpTools {}
436
437impl Item for AcpTools {
438 type Event = AcpToolsEvent;
439
440 fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
441 format!(
442 "ACP: {}",
443 self.watched_connection
444 .as_ref()
445 .map_or("Disconnected", |connection| &connection.server_name)
446 )
447 .into()
448 }
449
450 fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
451 Some(ui::Icon::new(IconName::Thread))
452 }
453}
454
455impl Focusable for AcpTools {
456 fn focus_handle(&self, _cx: &App) -> FocusHandle {
457 self.focus_handle.clone()
458 }
459}
460
461impl Render for AcpTools {
462 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
463 v_flex()
464 .track_focus(&self.focus_handle)
465 .size_full()
466 .bg(cx.theme().colors().editor_background)
467 .child(match self.watched_connection.as_ref() {
468 Some(connection) => {
469 if connection.messages.is_empty() {
470 h_flex()
471 .size_full()
472 .justify_center()
473 .items_center()
474 .child("No messages recorded yet")
475 .into_any()
476 } else {
477 list(
478 connection.list_state.clone(),
479 cx.processor(Self::render_message),
480 )
481 .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
482 .flex_grow()
483 .into_any()
484 }
485 }
486 None => h_flex()
487 .size_full()
488 .justify_center()
489 .items_center()
490 .child("No active connection")
491 .into_any(),
492 })
493 }
494}