acp_tools.rs

  1use std::{
  2    cell::RefCell,
  3    collections::HashSet,
  4    fmt::Display,
  5    rc::{Rc, Weak},
  6    sync::Arc,
  7};
  8
  9use agent_client_protocol as acp;
 10use collections::HashMap;
 11use gpui::{
 12    App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
 13    StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
 14};
 15use language::LanguageRegistry;
 16use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
 17use project::{AgentId, Project};
 18use settings::Settings;
 19use theme::ThemeSettings;
 20use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
 21use util::ResultExt as _;
 22use workspace::{
 23    Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
 24};
 25
// Declares the `dev::OpenAcpLogs` action; `init` registers a handler for it on
// every new workspace that opens an `AcpTools` log view.
actions!(dev, [OpenAcpLogs]);
 27
 28pub fn init(cx: &mut App) {
 29    cx.observe_new(
 30        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 31            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
 32                let acp_tools =
 33                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
 34                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
 35            });
 36        },
 37    )
 38    .detach();
 39}
 40
/// `Global` wrapper holding the app-wide [`AcpConnectionRegistry`] entity; it is
/// installed lazily by [`AcpConnectionRegistry::default_global`].
struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);

impl Global for GlobalAcpConnectionRegistry {}
 44
/// Remembers which ACP connection is currently "active" so that log views
/// (`AcpTools`) can observe this entity and follow it.
///
/// The connection itself is only referenced weakly (see `ActiveConnection`).
#[derive(Default)]
pub struct AcpConnectionRegistry {
    // Interior mutability lets `set_active_connection` replace the value via `&self`.
    active_connection: RefCell<Option<ActiveConnection>>,
}
 49
/// The most recently registered connection and the agent it belongs to.
struct ActiveConnection {
    agent_id: AgentId,
    // Weak, so registering a connection does not keep it alive; readers must
    // `upgrade()` it (and handle failure) before use.
    connection: Weak<acp::ClientSideConnection>,
}
 54
 55impl AcpConnectionRegistry {
 56    pub fn default_global(cx: &mut App) -> Entity<Self> {
 57        if cx.has_global::<GlobalAcpConnectionRegistry>() {
 58            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
 59        } else {
 60            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
 61            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
 62            registry
 63        }
 64    }
 65
 66    pub fn set_active_connection(
 67        &self,
 68        agent_id: AgentId,
 69        connection: &Rc<acp::ClientSideConnection>,
 70        cx: &mut Context<Self>,
 71    ) {
 72        self.active_connection.replace(Some(ActiveConnection {
 73            agent_id,
 74            connection: Rc::downgrade(connection),
 75        }));
 76        cx.notify();
 77    }
 78}
 79
/// Workspace item that logs the JSON-RPC traffic of the active ACP connection.
struct AcpTools {
    // Source of the language registry used to highlight JSON in the markdown views.
    project: Entity<Project>,
    focus_handle: FocusHandle,
    // Indices (into the watched connection's `messages`) of entries currently
    // shown with expanded params.
    expanded: HashSet<usize>,
    // `None` until an active connection has been upgraded and subscribed to.
    watched_connection: Option<WatchedConnection>,
    connection_registry: Entity<AcpConnectionRegistry>,
    // Observation of `connection_registry`; stored so it lives as long as the view.
    _subscription: Subscription,
}
 88
/// State for the connection an `AcpTools` view is currently logging.
struct WatchedConnection {
    agent_id: AgentId,
    // Messages recorded since the connection was picked up (or last cleared), in arrival order.
    messages: Vec<WatchedConnectionMessage>,
    // List state kept in sync with `messages`: one list item per message.
    list_state: ListState,
    // Weak handle compared against the registry to detect a change of active connection.
    connection: Weak<acp::ClientSideConnection>,
    // Method names of requests received from the peer, keyed by request id; used to
    // label the outgoing responses that answer them.
    incoming_request_methods: HashMap<acp::RequestId, Arc<str>>,
    // Method names of requests we sent, keyed by request id; used to label incoming responses.
    outgoing_request_methods: HashMap<acp::RequestId, Arc<str>>,
    // Task forwarding stream messages into the view. NOTE(review): assumed to be
    // cancelled when dropped (gpui `Task` semantics) — confirm.
    _task: Task<()>,
}
 97}
 98
 99impl AcpTools {
100    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
101        let connection_registry = AcpConnectionRegistry::default_global(cx);
102
103        let subscription = cx.observe(&connection_registry, |this, _, cx| {
104            this.update_connection(cx);
105            cx.notify();
106        });
107
108        let mut this = Self {
109            project,
110            focus_handle: cx.focus_handle(),
111            expanded: HashSet::default(),
112            watched_connection: None,
113            connection_registry,
114            _subscription: subscription,
115        };
116        this.update_connection(cx);
117        this
118    }
119
120    fn update_connection(&mut self, cx: &mut Context<Self>) {
121        let active_connection = self.connection_registry.read(cx).active_connection.borrow();
122        let Some(active_connection) = active_connection.as_ref() else {
123            return;
124        };
125
126        if let Some(watched_connection) = self.watched_connection.as_ref() {
127            if Weak::ptr_eq(
128                &watched_connection.connection,
129                &active_connection.connection,
130            ) {
131                return;
132            }
133        }
134
135        if let Some(connection) = active_connection.connection.upgrade() {
136            let mut receiver = connection.subscribe();
137            let task = cx.spawn(async move |this, cx| {
138                while let Ok(message) = receiver.recv().await {
139                    this.update(cx, |this, cx| {
140                        this.push_stream_message(message, cx);
141                    })
142                    .ok();
143                }
144            });
145
146            self.watched_connection = Some(WatchedConnection {
147                agent_id: active_connection.agent_id.clone(),
148                messages: vec![],
149                list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
150                connection: active_connection.connection.clone(),
151                incoming_request_methods: HashMap::default(),
152                outgoing_request_methods: HashMap::default(),
153                _task: task,
154            });
155        }
156    }
157
158    fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
159        let Some(connection) = self.watched_connection.as_mut() else {
160            return;
161        };
162        let language_registry = self.project.read(cx).languages().clone();
163        let index = connection.messages.len();
164
165        let (request_id, method, message_type, params) = match stream_message.message {
166            acp::StreamMessageContent::Request { id, method, params } => {
167                let method_map = match stream_message.direction {
168                    acp::StreamMessageDirection::Incoming => {
169                        &mut connection.incoming_request_methods
170                    }
171                    acp::StreamMessageDirection::Outgoing => {
172                        &mut connection.outgoing_request_methods
173                    }
174                };
175
176                method_map.insert(id.clone(), method.clone());
177                (Some(id), method.into(), MessageType::Request, Ok(params))
178            }
179            acp::StreamMessageContent::Response { id, result } => {
180                let method_map = match stream_message.direction {
181                    acp::StreamMessageDirection::Incoming => {
182                        &mut connection.outgoing_request_methods
183                    }
184                    acp::StreamMessageDirection::Outgoing => {
185                        &mut connection.incoming_request_methods
186                    }
187                };
188
189                if let Some(method) = method_map.remove(&id) {
190                    (Some(id), method.into(), MessageType::Response, result)
191                } else {
192                    (
193                        Some(id),
194                        "[unrecognized response]".into(),
195                        MessageType::Response,
196                        result,
197                    )
198                }
199            }
200            acp::StreamMessageContent::Notification { method, params } => {
201                (None, method.into(), MessageType::Notification, Ok(params))
202            }
203        };
204
205        let message = WatchedConnectionMessage {
206            name: method,
207            message_type,
208            request_id,
209            direction: stream_message.direction,
210            collapsed_params_md: match params.as_ref() {
211                Ok(params) => params
212                    .as_ref()
213                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
214                Err(err) => {
215                    if let Ok(err) = &serde_json::to_value(err) {
216                        Some(collapsed_params_md(&err, &language_registry, cx))
217                    } else {
218                        None
219                    }
220                }
221            },
222
223            expanded_params_md: None,
224            params,
225        };
226
227        connection.messages.push(message);
228        connection.list_state.splice(index..index, 1);
229        cx.notify();
230    }
231
232    fn serialize_observed_messages(&self) -> Option<String> {
233        let connection = self.watched_connection.as_ref()?;
234
235        let messages: Vec<serde_json::Value> = connection
236            .messages
237            .iter()
238            .filter_map(|message| {
239                let params = match &message.params {
240                    Ok(Some(params)) => params.clone(),
241                    Ok(None) => serde_json::Value::Null,
242                    Err(err) => serde_json::to_value(err).ok()?,
243                };
244                Some(serde_json::json!({
245                    "_direction": match message.direction {
246                        acp::StreamMessageDirection::Incoming => "incoming",
247                        acp::StreamMessageDirection::Outgoing => "outgoing",
248                    },
249                    "_type": message.message_type.to_string().to_lowercase(),
250                    "id": message.request_id,
251                    "method": message.name.to_string(),
252                    "params": params,
253                }))
254            })
255            .collect();
256
257        serde_json::to_string_pretty(&messages).ok()
258    }
259
260    fn clear_messages(&mut self, cx: &mut Context<Self>) {
261        if let Some(connection) = self.watched_connection.as_mut() {
262            connection.messages.clear();
263            connection.list_state.reset(0);
264            self.expanded.clear();
265            cx.notify();
266        }
267    }
268
269    fn render_message(
270        &mut self,
271        index: usize,
272        window: &mut Window,
273        cx: &mut Context<Self>,
274    ) -> AnyElement {
275        let Some(connection) = self.watched_connection.as_ref() else {
276            return Empty.into_any();
277        };
278
279        let Some(message) = connection.messages.get(index) else {
280            return Empty.into_any();
281        };
282
283        let base_size = TextSize::Editor.rems(cx);
284
285        let theme_settings = ThemeSettings::get_global(cx);
286        let text_style = window.text_style();
287
288        let colors = cx.theme().colors();
289        let expanded = self.expanded.contains(&index);
290
291        v_flex()
292            .id(index)
293            .group("message")
294            .font_buffer(cx)
295            .w_full()
296            .py_3()
297            .pl_4()
298            .pr_5()
299            .gap_2()
300            .items_start()
301            .text_size(base_size)
302            .border_color(colors.border)
303            .border_b_1()
304            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
305            .child(
306                h_flex()
307                    .id(("acp-log-message-header", index))
308                    .w_full()
309                    .gap_2()
310                    .flex_shrink_0()
311                    .cursor_pointer()
312                    .on_click(cx.listener(move |this, _, _, cx| {
313                        if this.expanded.contains(&index) {
314                            this.expanded.remove(&index);
315                        } else {
316                            this.expanded.insert(index);
317                            let Some(connection) = &mut this.watched_connection else {
318                                return;
319                            };
320                            let Some(message) = connection.messages.get_mut(index) else {
321                                return;
322                            };
323                            message.expanded(this.project.read(cx).languages().clone(), cx);
324                            connection.list_state.scroll_to_reveal_item(index);
325                        }
326                        cx.notify()
327                    }))
328                    .child(match message.direction {
329                        acp::StreamMessageDirection::Incoming => Icon::new(IconName::ArrowDown)
330                            .color(Color::Error)
331                            .size(IconSize::Small),
332                        acp::StreamMessageDirection::Outgoing => Icon::new(IconName::ArrowUp)
333                            .color(Color::Success)
334                            .size(IconSize::Small),
335                    })
336                    .child(
337                        Label::new(message.name.clone())
338                            .buffer_font(cx)
339                            .color(Color::Muted),
340                    )
341                    .child(div().flex_1())
342                    .child(
343                        div()
344                            .child(ui::Chip::new(message.message_type.to_string()))
345                            .visible_on_hover("message"),
346                    )
347                    .children(
348                        message
349                            .request_id
350                            .as_ref()
351                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
352                    ),
353            )
354            // I'm aware using markdown is a hack. Trying to get something working for the demo.
355            // Will clean up soon!
356            .when_some(
357                if expanded {
358                    message.expanded_params_md.clone()
359                } else {
360                    message.collapsed_params_md.clone()
361                },
362                |this, params| {
363                    this.child(
364                        div().pl_6().w_full().child(
365                            MarkdownElement::new(
366                                params,
367                                MarkdownStyle {
368                                    base_text_style: text_style,
369                                    selection_background_color: colors.element_selection_background,
370                                    syntax: cx.theme().syntax().clone(),
371                                    code_block_overflow_x_scroll: true,
372                                    code_block: StyleRefinement {
373                                        text: TextStyleRefinement {
374                                            font_family: Some(
375                                                theme_settings.buffer_font.family.clone(),
376                                            ),
377                                            font_size: Some((base_size * 0.8).into()),
378                                            ..Default::default()
379                                        },
380                                        ..Default::default()
381                                    },
382                                    ..Default::default()
383                                },
384                            )
385                            .code_block_renderer(
386                                CodeBlockRenderer::Default {
387                                    copy_button: false,
388                                    copy_button_on_hover: expanded,
389                                    border: false,
390                                },
391                            ),
392                        ),
393                    )
394                },
395            )
396            .into_any()
397    }
398}
399
/// One logged JSON-RPC message together with its markdown renderings.
struct WatchedConnectionMessage {
    // Method name; for responses, the method of the matching request, or a
    // placeholder when that request was not observed.
    name: SharedString,
    // `None` for notifications.
    request_id: Option<acp::RequestId>,
    direction: acp::StreamMessageDirection,
    message_type: MessageType,
    // Request/notification params, or a response's result (which may be an error).
    params: Result<Option<serde_json::Value>, acp::Error>,
    // Single-line JSON rendering, built when the message is recorded.
    collapsed_params_md: Option<Entity<Markdown>>,
    // Pretty-printed JSON rendering, built when the message is expanded in the UI.
    expanded_params_md: Option<Entity<Markdown>>,
}
409
410impl WatchedConnectionMessage {
411    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
412        let params_md = match &self.params {
413            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
414            Err(err) => {
415                if let Some(err) = &serde_json::to_value(err).log_err() {
416                    Some(expanded_params_md(&err, &language_registry, cx))
417                } else {
418                    None
419                }
420            }
421            _ => None,
422        };
423        self.expanded_params_md = params_md;
424    }
425}
426
427fn collapsed_params_md(
428    params: &serde_json::Value,
429    language_registry: &Arc<LanguageRegistry>,
430    cx: &mut App,
431) -> Entity<Markdown> {
432    let params_json = serde_json::to_string(params).unwrap_or_default();
433    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
434
435    for ch in params_json.chars() {
436        match ch {
437            '{' => spaced_out_json.push_str("{ "),
438            '}' => spaced_out_json.push_str(" }"),
439            ':' => spaced_out_json.push_str(": "),
440            ',' => spaced_out_json.push_str(", "),
441            c => spaced_out_json.push(c),
442        }
443    }
444
445    let params_md = format!("```json\n{}\n```", spaced_out_json);
446    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
447}
448
449fn expanded_params_md(
450    params: &serde_json::Value,
451    language_registry: &Arc<LanguageRegistry>,
452    cx: &mut App,
453) -> Entity<Markdown> {
454    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
455    let params_md = format!("```json\n{}\n```", params_json);
456    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
457}
458
/// Kind of JSON-RPC message shown in the log; `Display` renders the variant name.
enum MessageType {
    Request,
    Response,
    Notification,
}
464
465impl Display for MessageType {
466    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467        match self {
468            MessageType::Request => write!(f, "Request"),
469            MessageType::Response => write!(f, "Response"),
470            MessageType::Notification => write!(f, "Notification"),
471        }
472    }
473}
474
/// Event type for the `Item` impl. It has no variants, so `AcpTools` can never
/// emit an event.
enum AcpToolsEvent {}

impl EventEmitter<AcpToolsEvent> for AcpTools {}
478
479impl Item for AcpTools {
480    type Event = AcpToolsEvent;
481
482    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
483        format!(
484            "ACP: {}",
485            self.watched_connection
486                .as_ref()
487                .map_or("Disconnected", |connection| connection.agent_id.0.as_ref())
488        )
489        .into()
490    }
491
492    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
493        Some(ui::Icon::new(IconName::Thread))
494    }
495}
496
497impl Focusable for AcpTools {
498    fn focus_handle(&self, _cx: &App) -> FocusHandle {
499        self.focus_handle.clone()
500    }
501}
502
503impl Render for AcpTools {
504    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
505        v_flex()
506            .track_focus(&self.focus_handle)
507            .size_full()
508            .bg(cx.theme().colors().editor_background)
509            .child(match self.watched_connection.as_ref() {
510                Some(connection) => {
511                    if connection.messages.is_empty() {
512                        h_flex()
513                            .size_full()
514                            .justify_center()
515                            .items_center()
516                            .child("No messages recorded yet")
517                            .into_any()
518                    } else {
519                        div()
520                            .size_full()
521                            .flex_grow()
522                            .child(
523                                list(
524                                    connection.list_state.clone(),
525                                    cx.processor(Self::render_message),
526                                )
527                                .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
528                                .size_full(),
529                            )
530                            .vertical_scrollbar_for(&connection.list_state, window, cx)
531                            .into_any()
532                    }
533                }
534                None => h_flex()
535                    .size_full()
536                    .justify_center()
537                    .items_center()
538                    .child("No active connection")
539                    .into_any(),
540            })
541    }
542}
543
/// Toolbar controls ("copy all messages" and "clear messages") shown while an
/// `AcpTools` item is the active pane item.
pub struct AcpToolsToolbarItemView {
    // The active `AcpTools` item, or `None` when some other item is active.
    acp_tools: Option<Entity<AcpTools>>,
}
547
548impl AcpToolsToolbarItemView {
549    pub fn new() -> Self {
550        Self { acp_tools: None }
551    }
552}
553
554impl Render for AcpToolsToolbarItemView {
555    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
556        let Some(acp_tools) = self.acp_tools.as_ref() else {
557            return Empty.into_any_element();
558        };
559
560        let acp_tools = acp_tools.clone();
561        let has_messages = acp_tools
562            .read(cx)
563            .watched_connection
564            .as_ref()
565            .is_some_and(|connection| !connection.messages.is_empty());
566
567        h_flex()
568            .gap_2()
569            .child({
570                let message = acp_tools
571                    .read(cx)
572                    .serialize_observed_messages()
573                    .unwrap_or_default();
574
575                CopyButton::new("copy-all-messages", message)
576                    .tooltip_label("Copy All Messages")
577                    .disabled(!has_messages)
578            })
579            .child(
580                IconButton::new("clear_messages", IconName::Trash)
581                    .icon_size(IconSize::Small)
582                    .tooltip(Tooltip::text("Clear Messages"))
583                    .disabled(!has_messages)
584                    .on_click(cx.listener(move |_this, _, _window, cx| {
585                        acp_tools.update(cx, |acp_tools, cx| {
586                            acp_tools.clear_messages(cx);
587                        });
588                    })),
589            )
590            .into_any()
591    }
592}
593
// Lets the toolbar view emit `ToolbarItemEvent`s. NOTE(review): presumably
// required by the `ToolbarItemView` trait — confirm.
impl EventEmitter<ToolbarItemEvent> for AcpToolsToolbarItemView {}
595
596impl ToolbarItemView for AcpToolsToolbarItemView {
597    fn set_active_pane_item(
598        &mut self,
599        active_pane_item: Option<&dyn ItemHandle>,
600        _window: &mut Window,
601        cx: &mut Context<Self>,
602    ) -> ToolbarItemLocation {
603        if let Some(item) = active_pane_item
604            && let Some(acp_tools) = item.downcast::<AcpTools>()
605        {
606            self.acp_tools = Some(acp_tools);
607            cx.notify();
608            return ToolbarItemLocation::PrimaryRight;
609        }
610        if self.acp_tools.take().is_some() {
611            cx.notify();
612        }
613        ToolbarItemLocation::Hidden
614    }
615}