acp_tools.rs

  1use std::{
  2    cell::RefCell,
  3    collections::HashSet,
  4    fmt::Display,
  5    rc::{Rc, Weak},
  6    sync::Arc,
  7};
  8
  9use agent_client_protocol as acp;
 10use collections::HashMap;
 11use gpui::{
 12    App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
 13    StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
 14};
 15use language::LanguageRegistry;
 16use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
 17use project::Project;
 18use settings::Settings;
 19use theme::ThemeSettings;
 20use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
 21use util::ResultExt as _;
 22use workspace::{
 23    Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
 24};
 25
 26actions!(dev, [OpenAcpLogs]);
 27
 28pub fn init(cx: &mut App) {
 29    cx.observe_new(
 30        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 31            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
 32                let acp_tools =
 33                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
 34                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
 35            });
 36        },
 37    )
 38    .detach();
 39}
 40
 41struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);
 42
 43impl Global for GlobalAcpConnectionRegistry {}
 44
 45#[derive(Default)]
 46pub struct AcpConnectionRegistry {
 47    active_connection: RefCell<Option<ActiveConnection>>,
 48}
 49
 50struct ActiveConnection {
 51    server_name: SharedString,
 52    connection: Weak<acp::ClientSideConnection>,
 53}
 54
 55impl AcpConnectionRegistry {
 56    pub fn default_global(cx: &mut App) -> Entity<Self> {
 57        if cx.has_global::<GlobalAcpConnectionRegistry>() {
 58            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
 59        } else {
 60            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
 61            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
 62            registry
 63        }
 64    }
 65
 66    pub fn set_active_connection(
 67        &self,
 68        server_name: impl Into<SharedString>,
 69        connection: &Rc<acp::ClientSideConnection>,
 70        cx: &mut Context<Self>,
 71    ) {
 72        self.active_connection.replace(Some(ActiveConnection {
 73            server_name: server_name.into(),
 74            connection: Rc::downgrade(connection),
 75        }));
 76        cx.notify();
 77    }
 78}
 79
 80struct AcpTools {
 81    project: Entity<Project>,
 82    focus_handle: FocusHandle,
 83    expanded: HashSet<usize>,
 84    watched_connection: Option<WatchedConnection>,
 85    connection_registry: Entity<AcpConnectionRegistry>,
 86    _subscription: Subscription,
 87}
 88
 89struct WatchedConnection {
 90    server_name: SharedString,
 91    messages: Vec<WatchedConnectionMessage>,
 92    list_state: ListState,
 93    connection: Weak<acp::ClientSideConnection>,
 94    incoming_request_methods: HashMap<acp::RequestId, Arc<str>>,
 95    outgoing_request_methods: HashMap<acp::RequestId, Arc<str>>,
 96    _task: Task<()>,
 97}
 98
 99impl AcpTools {
100    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
101        let connection_registry = AcpConnectionRegistry::default_global(cx);
102
103        let subscription = cx.observe(&connection_registry, |this, _, cx| {
104            this.update_connection(cx);
105            cx.notify();
106        });
107
108        let mut this = Self {
109            project,
110            focus_handle: cx.focus_handle(),
111            expanded: HashSet::default(),
112            watched_connection: None,
113            connection_registry,
114            _subscription: subscription,
115        };
116        this.update_connection(cx);
117        this
118    }
119
120    fn update_connection(&mut self, cx: &mut Context<Self>) {
121        let active_connection = self.connection_registry.read(cx).active_connection.borrow();
122        let Some(active_connection) = active_connection.as_ref() else {
123            return;
124        };
125
126        if let Some(watched_connection) = self.watched_connection.as_ref() {
127            if Weak::ptr_eq(
128                &watched_connection.connection,
129                &active_connection.connection,
130            ) {
131                return;
132            }
133        }
134
135        if let Some(connection) = active_connection.connection.upgrade() {
136            let mut receiver = connection.subscribe();
137            let task = cx.spawn(async move |this, cx| {
138                while let Ok(message) = receiver.recv().await {
139                    this.update(cx, |this, cx| {
140                        this.push_stream_message(message, cx);
141                    })
142                    .ok();
143                }
144            });
145
146            self.watched_connection = Some(WatchedConnection {
147                server_name: active_connection.server_name.clone(),
148                messages: vec![],
149                list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
150                connection: active_connection.connection.clone(),
151                incoming_request_methods: HashMap::default(),
152                outgoing_request_methods: HashMap::default(),
153                _task: task,
154            });
155        }
156    }
157
158    fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
159        let Some(connection) = self.watched_connection.as_mut() else {
160            return;
161        };
162        let language_registry = self.project.read(cx).languages().clone();
163        let index = connection.messages.len();
164
165        let (request_id, method, message_type, params) = match stream_message.message {
166            acp::StreamMessageContent::Request { id, method, params } => {
167                let method_map = match stream_message.direction {
168                    acp::StreamMessageDirection::Incoming => {
169                        &mut connection.incoming_request_methods
170                    }
171                    acp::StreamMessageDirection::Outgoing => {
172                        &mut connection.outgoing_request_methods
173                    }
174                };
175
176                method_map.insert(id.clone(), method.clone());
177                (Some(id), method.into(), MessageType::Request, Ok(params))
178            }
179            acp::StreamMessageContent::Response { id, result } => {
180                let method_map = match stream_message.direction {
181                    acp::StreamMessageDirection::Incoming => {
182                        &mut connection.outgoing_request_methods
183                    }
184                    acp::StreamMessageDirection::Outgoing => {
185                        &mut connection.incoming_request_methods
186                    }
187                };
188
189                if let Some(method) = method_map.remove(&id) {
190                    (Some(id), method.into(), MessageType::Response, result)
191                } else {
192                    (
193                        Some(id),
194                        "[unrecognized response]".into(),
195                        MessageType::Response,
196                        result,
197                    )
198                }
199            }
200            acp::StreamMessageContent::Notification { method, params } => {
201                (None, method.into(), MessageType::Notification, Ok(params))
202            }
203        };
204
205        let message = WatchedConnectionMessage {
206            name: method,
207            message_type,
208            request_id,
209            direction: stream_message.direction,
210            collapsed_params_md: match params.as_ref() {
211                Ok(params) => params
212                    .as_ref()
213                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
214                Err(err) => {
215                    if let Ok(err) = &serde_json::to_value(err) {
216                        Some(collapsed_params_md(&err, &language_registry, cx))
217                    } else {
218                        None
219                    }
220                }
221            },
222
223            expanded_params_md: None,
224            params,
225        };
226
227        connection.messages.push(message);
228        connection.list_state.splice(index..index, 1);
229        cx.notify();
230    }
231
232    fn serialize_observed_messages(&self) -> Option<String> {
233        let connection = self.watched_connection.as_ref()?;
234
235        let messages: Vec<serde_json::Value> = connection
236            .messages
237            .iter()
238            .filter_map(|message| {
239                let params = match &message.params {
240                    Ok(Some(params)) => params.clone(),
241                    Ok(None) => serde_json::Value::Null,
242                    Err(err) => serde_json::to_value(err).ok()?,
243                };
244                Some(serde_json::json!({
245                    "_direction": match message.direction {
246                        acp::StreamMessageDirection::Incoming => "incoming",
247                        acp::StreamMessageDirection::Outgoing => "outgoing",
248                    },
249                    "_type": message.message_type.to_string().to_lowercase(),
250                    "id": message.request_id,
251                    "method": message.name.to_string(),
252                    "params": params,
253                }))
254            })
255            .collect();
256
257        serde_json::to_string_pretty(&messages).ok()
258    }
259
260    fn clear_messages(&mut self, cx: &mut Context<Self>) {
261        if let Some(connection) = self.watched_connection.as_mut() {
262            connection.messages.clear();
263            connection.list_state.reset(0);
264            self.expanded.clear();
265            cx.notify();
266        }
267    }
268
269    fn render_message(
270        &mut self,
271        index: usize,
272        window: &mut Window,
273        cx: &mut Context<Self>,
274    ) -> AnyElement {
275        let Some(connection) = self.watched_connection.as_ref() else {
276            return Empty.into_any();
277        };
278
279        let Some(message) = connection.messages.get(index) else {
280            return Empty.into_any();
281        };
282
283        let base_size = TextSize::Editor.rems(cx);
284
285        let theme_settings = ThemeSettings::get_global(cx);
286        let text_style = window.text_style();
287
288        let colors = cx.theme().colors();
289        let expanded = self.expanded.contains(&index);
290
291        v_flex()
292            .id(index)
293            .group("message")
294            .cursor_pointer()
295            .font_buffer(cx)
296            .w_full()
297            .py_3()
298            .pl_4()
299            .pr_5()
300            .gap_2()
301            .items_start()
302            .text_size(base_size)
303            .border_color(colors.border)
304            .border_b_1()
305            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
306            .on_click(cx.listener(move |this, _, _, cx| {
307                if this.expanded.contains(&index) {
308                    this.expanded.remove(&index);
309                } else {
310                    this.expanded.insert(index);
311                    let Some(connection) = &mut this.watched_connection else {
312                        return;
313                    };
314                    let Some(message) = connection.messages.get_mut(index) else {
315                        return;
316                    };
317                    message.expanded(this.project.read(cx).languages().clone(), cx);
318                    connection.list_state.scroll_to_reveal_item(index);
319                }
320                cx.notify()
321            }))
322            .child(
323                h_flex()
324                    .w_full()
325                    .gap_2()
326                    .flex_shrink_0()
327                    .child(match message.direction {
328                        acp::StreamMessageDirection::Incoming => Icon::new(IconName::ArrowDown)
329                            .color(Color::Error)
330                            .size(IconSize::Small),
331                        acp::StreamMessageDirection::Outgoing => Icon::new(IconName::ArrowUp)
332                            .color(Color::Success)
333                            .size(IconSize::Small),
334                    })
335                    .child(
336                        Label::new(message.name.clone())
337                            .buffer_font(cx)
338                            .color(Color::Muted),
339                    )
340                    .child(div().flex_1())
341                    .child(
342                        div()
343                            .child(ui::Chip::new(message.message_type.to_string()))
344                            .visible_on_hover("message"),
345                    )
346                    .children(
347                        message
348                            .request_id
349                            .as_ref()
350                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
351                    ),
352            )
353            // I'm aware using markdown is a hack. Trying to get something working for the demo.
354            // Will clean up soon!
355            .when_some(
356                if expanded {
357                    message.expanded_params_md.clone()
358                } else {
359                    message.collapsed_params_md.clone()
360                },
361                |this, params| {
362                    this.child(
363                        div().pl_6().w_full().child(
364                            MarkdownElement::new(
365                                params,
366                                MarkdownStyle {
367                                    base_text_style: text_style,
368                                    selection_background_color: colors.element_selection_background,
369                                    syntax: cx.theme().syntax().clone(),
370                                    code_block_overflow_x_scroll: true,
371                                    code_block: StyleRefinement {
372                                        text: TextStyleRefinement {
373                                            font_family: Some(
374                                                theme_settings.buffer_font.family.clone(),
375                                            ),
376                                            font_size: Some((base_size * 0.8).into()),
377                                            ..Default::default()
378                                        },
379                                        ..Default::default()
380                                    },
381                                    ..Default::default()
382                                },
383                            )
384                            .code_block_renderer(
385                                CodeBlockRenderer::Default {
386                                    copy_button: false,
387                                    copy_button_on_hover: expanded,
388                                    border: false,
389                                },
390                            ),
391                        ),
392                    )
393                },
394            )
395            .into_any()
396    }
397}
398
399struct WatchedConnectionMessage {
400    name: SharedString,
401    request_id: Option<acp::RequestId>,
402    direction: acp::StreamMessageDirection,
403    message_type: MessageType,
404    params: Result<Option<serde_json::Value>, acp::Error>,
405    collapsed_params_md: Option<Entity<Markdown>>,
406    expanded_params_md: Option<Entity<Markdown>>,
407}
408
409impl WatchedConnectionMessage {
410    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
411        let params_md = match &self.params {
412            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
413            Err(err) => {
414                if let Some(err) = &serde_json::to_value(err).log_err() {
415                    Some(expanded_params_md(&err, &language_registry, cx))
416                } else {
417                    None
418                }
419            }
420            _ => None,
421        };
422        self.expanded_params_md = params_md;
423    }
424}
425
426fn collapsed_params_md(
427    params: &serde_json::Value,
428    language_registry: &Arc<LanguageRegistry>,
429    cx: &mut App,
430) -> Entity<Markdown> {
431    let params_json = serde_json::to_string(params).unwrap_or_default();
432    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
433
434    for ch in params_json.chars() {
435        match ch {
436            '{' => spaced_out_json.push_str("{ "),
437            '}' => spaced_out_json.push_str(" }"),
438            ':' => spaced_out_json.push_str(": "),
439            ',' => spaced_out_json.push_str(", "),
440            c => spaced_out_json.push(c),
441        }
442    }
443
444    let params_md = format!("```json\n{}\n```", spaced_out_json);
445    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
446}
447
448fn expanded_params_md(
449    params: &serde_json::Value,
450    language_registry: &Arc<LanguageRegistry>,
451    cx: &mut App,
452) -> Entity<Markdown> {
453    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
454    let params_md = format!("```json\n{}\n```", params_json);
455    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
456}
457
/// Kind of JSON-RPC message shown in the log.
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    /// Writes the capitalized variant name, e.g. `Request`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Request => "Request",
            Self::Response => "Response",
            Self::Notification => "Notification",
        };
        f.write_str(label)
    }
}
473
474enum AcpToolsEvent {}
475
476impl EventEmitter<AcpToolsEvent> for AcpTools {}
477
478impl Item for AcpTools {
479    type Event = AcpToolsEvent;
480
481    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
482        format!(
483            "ACP: {}",
484            self.watched_connection
485                .as_ref()
486                .map_or("Disconnected", |connection| &connection.server_name)
487        )
488        .into()
489    }
490
491    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
492        Some(ui::Icon::new(IconName::Thread))
493    }
494}
495
496impl Focusable for AcpTools {
497    fn focus_handle(&self, _cx: &App) -> FocusHandle {
498        self.focus_handle.clone()
499    }
500}
501
502impl Render for AcpTools {
503    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
504        v_flex()
505            .track_focus(&self.focus_handle)
506            .size_full()
507            .bg(cx.theme().colors().editor_background)
508            .child(match self.watched_connection.as_ref() {
509                Some(connection) => {
510                    if connection.messages.is_empty() {
511                        h_flex()
512                            .size_full()
513                            .justify_center()
514                            .items_center()
515                            .child("No messages recorded yet")
516                            .into_any()
517                    } else {
518                        div()
519                            .size_full()
520                            .flex_grow()
521                            .child(
522                                list(
523                                    connection.list_state.clone(),
524                                    cx.processor(Self::render_message),
525                                )
526                                .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
527                                .size_full(),
528                            )
529                            .vertical_scrollbar_for(&connection.list_state, window, cx)
530                            .into_any()
531                    }
532                }
533                None => h_flex()
534                    .size_full()
535                    .justify_center()
536                    .items_center()
537                    .child("No active connection")
538                    .into_any(),
539            })
540    }
541}
542
543pub struct AcpToolsToolbarItemView {
544    acp_tools: Option<Entity<AcpTools>>,
545}
546
547impl AcpToolsToolbarItemView {
548    pub fn new() -> Self {
549        Self { acp_tools: None }
550    }
551}
552
553impl Render for AcpToolsToolbarItemView {
554    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
555        let Some(acp_tools) = self.acp_tools.as_ref() else {
556            return Empty.into_any_element();
557        };
558
559        let acp_tools = acp_tools.clone();
560        let has_messages = acp_tools
561            .read(cx)
562            .watched_connection
563            .as_ref()
564            .is_some_and(|connection| !connection.messages.is_empty());
565
566        h_flex()
567            .gap_2()
568            .child({
569                let message = acp_tools
570                    .read(cx)
571                    .serialize_observed_messages()
572                    .unwrap_or_default();
573
574                CopyButton::new(message)
575                    .tooltip_label("Copy All Messages")
576                    .disabled(!has_messages)
577            })
578            .child(
579                IconButton::new("clear_messages", IconName::Trash)
580                    .icon_size(IconSize::Small)
581                    .tooltip(Tooltip::text("Clear Messages"))
582                    .disabled(!has_messages)
583                    .on_click(cx.listener(move |_this, _, _window, cx| {
584                        acp_tools.update(cx, |acp_tools, cx| {
585                            acp_tools.clear_messages(cx);
586                        });
587                    })),
588            )
589            .into_any()
590    }
591}
592
593impl EventEmitter<ToolbarItemEvent> for AcpToolsToolbarItemView {}
594
595impl ToolbarItemView for AcpToolsToolbarItemView {
596    fn set_active_pane_item(
597        &mut self,
598        active_pane_item: Option<&dyn ItemHandle>,
599        _window: &mut Window,
600        cx: &mut Context<Self>,
601    ) -> ToolbarItemLocation {
602        if let Some(item) = active_pane_item
603            && let Some(acp_tools) = item.downcast::<AcpTools>()
604        {
605            self.acp_tools = Some(acp_tools);
606            cx.notify();
607            return ToolbarItemLocation::PrimaryRight;
608        }
609        if self.acp_tools.take().is_some() {
610            cx.notify();
611        }
612        ToolbarItemLocation::Hidden
613    }
614}