acp_tools.rs

  1use std::{
  2    cell::RefCell,
  3    collections::HashSet,
  4    fmt::Display,
  5    rc::{Rc, Weak},
  6    sync::Arc,
  7    time::Duration,
  8};
  9
 10use agent_client_protocol as acp;
 11use collections::HashMap;
 12use gpui::{
 13    App, ClipboardItem, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment,
 14    ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list,
 15    prelude::*,
 16};
 17use language::LanguageRegistry;
 18use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
 19use project::Project;
 20use settings::Settings;
 21use theme::ThemeSettings;
 22use ui::{Tooltip, WithScrollbar, prelude::*};
 23use util::ResultExt as _;
 24use workspace::{
 25    Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
 26};
 27
 28actions!(dev, [OpenAcpLogs]);
 29
 30pub fn init(cx: &mut App) {
 31    cx.observe_new(
 32        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 33            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
 34                let acp_tools =
 35                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
 36                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
 37            });
 38        },
 39    )
 40    .detach();
 41}
 42
 43struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);
 44
 45impl Global for GlobalAcpConnectionRegistry {}
 46
 47#[derive(Default)]
 48pub struct AcpConnectionRegistry {
 49    active_connection: RefCell<Option<ActiveConnection>>,
 50}
 51
 52struct ActiveConnection {
 53    server_name: SharedString,
 54    connection: Weak<acp::ClientSideConnection>,
 55}
 56
 57impl AcpConnectionRegistry {
 58    pub fn default_global(cx: &mut App) -> Entity<Self> {
 59        if cx.has_global::<GlobalAcpConnectionRegistry>() {
 60            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
 61        } else {
 62            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
 63            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
 64            registry
 65        }
 66    }
 67
 68    pub fn set_active_connection(
 69        &self,
 70        server_name: impl Into<SharedString>,
 71        connection: &Rc<acp::ClientSideConnection>,
 72        cx: &mut Context<Self>,
 73    ) {
 74        self.active_connection.replace(Some(ActiveConnection {
 75            server_name: server_name.into(),
 76            connection: Rc::downgrade(connection),
 77        }));
 78        cx.notify();
 79    }
 80}
 81
 82struct AcpTools {
 83    project: Entity<Project>,
 84    focus_handle: FocusHandle,
 85    expanded: HashSet<usize>,
 86    watched_connection: Option<WatchedConnection>,
 87    connection_registry: Entity<AcpConnectionRegistry>,
 88    _subscription: Subscription,
 89}
 90
 91struct WatchedConnection {
 92    server_name: SharedString,
 93    messages: Vec<WatchedConnectionMessage>,
 94    list_state: ListState,
 95    connection: Weak<acp::ClientSideConnection>,
 96    incoming_request_methods: HashMap<acp::RequestId, Arc<str>>,
 97    outgoing_request_methods: HashMap<acp::RequestId, Arc<str>>,
 98    _task: Task<()>,
 99}
100
101impl AcpTools {
102    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
103        let connection_registry = AcpConnectionRegistry::default_global(cx);
104
105        let subscription = cx.observe(&connection_registry, |this, _, cx| {
106            this.update_connection(cx);
107            cx.notify();
108        });
109
110        let mut this = Self {
111            project,
112            focus_handle: cx.focus_handle(),
113            expanded: HashSet::default(),
114            watched_connection: None,
115            connection_registry,
116            _subscription: subscription,
117        };
118        this.update_connection(cx);
119        this
120    }
121
122    fn update_connection(&mut self, cx: &mut Context<Self>) {
123        let active_connection = self.connection_registry.read(cx).active_connection.borrow();
124        let Some(active_connection) = active_connection.as_ref() else {
125            return;
126        };
127
128        if let Some(watched_connection) = self.watched_connection.as_ref() {
129            if Weak::ptr_eq(
130                &watched_connection.connection,
131                &active_connection.connection,
132            ) {
133                return;
134            }
135        }
136
137        if let Some(connection) = active_connection.connection.upgrade() {
138            let mut receiver = connection.subscribe();
139            let task = cx.spawn(async move |this, cx| {
140                while let Ok(message) = receiver.recv().await {
141                    this.update(cx, |this, cx| {
142                        this.push_stream_message(message, cx);
143                    })
144                    .ok();
145                }
146            });
147
148            self.watched_connection = Some(WatchedConnection {
149                server_name: active_connection.server_name.clone(),
150                messages: vec![],
151                list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
152                connection: active_connection.connection.clone(),
153                incoming_request_methods: HashMap::default(),
154                outgoing_request_methods: HashMap::default(),
155                _task: task,
156            });
157        }
158    }
159
160    fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
161        let Some(connection) = self.watched_connection.as_mut() else {
162            return;
163        };
164        let language_registry = self.project.read(cx).languages().clone();
165        let index = connection.messages.len();
166
167        let (request_id, method, message_type, params) = match stream_message.message {
168            acp::StreamMessageContent::Request { id, method, params } => {
169                let method_map = match stream_message.direction {
170                    acp::StreamMessageDirection::Incoming => {
171                        &mut connection.incoming_request_methods
172                    }
173                    acp::StreamMessageDirection::Outgoing => {
174                        &mut connection.outgoing_request_methods
175                    }
176                };
177
178                method_map.insert(id.clone(), method.clone());
179                (Some(id), method.into(), MessageType::Request, Ok(params))
180            }
181            acp::StreamMessageContent::Response { id, result } => {
182                let method_map = match stream_message.direction {
183                    acp::StreamMessageDirection::Incoming => {
184                        &mut connection.outgoing_request_methods
185                    }
186                    acp::StreamMessageDirection::Outgoing => {
187                        &mut connection.incoming_request_methods
188                    }
189                };
190
191                if let Some(method) = method_map.remove(&id) {
192                    (Some(id), method.into(), MessageType::Response, result)
193                } else {
194                    (
195                        Some(id),
196                        "[unrecognized response]".into(),
197                        MessageType::Response,
198                        result,
199                    )
200                }
201            }
202            acp::StreamMessageContent::Notification { method, params } => {
203                (None, method.into(), MessageType::Notification, Ok(params))
204            }
205        };
206
207        let message = WatchedConnectionMessage {
208            name: method,
209            message_type,
210            request_id,
211            direction: stream_message.direction,
212            collapsed_params_md: match params.as_ref() {
213                Ok(params) => params
214                    .as_ref()
215                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
216                Err(err) => {
217                    if let Ok(err) = &serde_json::to_value(err) {
218                        Some(collapsed_params_md(&err, &language_registry, cx))
219                    } else {
220                        None
221                    }
222                }
223            },
224
225            expanded_params_md: None,
226            params,
227        };
228
229        connection.messages.push(message);
230        connection.list_state.splice(index..index, 1);
231        cx.notify();
232    }
233
234    fn serialize_observed_messages(&self) -> Option<String> {
235        let connection = self.watched_connection.as_ref()?;
236
237        let messages: Vec<serde_json::Value> = connection
238            .messages
239            .iter()
240            .filter_map(|message| {
241                let params = match &message.params {
242                    Ok(Some(params)) => params.clone(),
243                    Ok(None) => serde_json::Value::Null,
244                    Err(err) => serde_json::to_value(err).ok()?,
245                };
246                Some(serde_json::json!({
247                    "_direction": match message.direction {
248                        acp::StreamMessageDirection::Incoming => "incoming",
249                        acp::StreamMessageDirection::Outgoing => "outgoing",
250                    },
251                    "_type": message.message_type.to_string().to_lowercase(),
252                    "id": message.request_id,
253                    "method": message.name.to_string(),
254                    "params": params,
255                }))
256            })
257            .collect();
258
259        serde_json::to_string_pretty(&messages).ok()
260    }
261
262    fn clear_messages(&mut self, cx: &mut Context<Self>) {
263        if let Some(connection) = self.watched_connection.as_mut() {
264            connection.messages.clear();
265            connection.list_state.reset(0);
266            self.expanded.clear();
267            cx.notify();
268        }
269    }
270
271    fn render_message(
272        &mut self,
273        index: usize,
274        window: &mut Window,
275        cx: &mut Context<Self>,
276    ) -> AnyElement {
277        let Some(connection) = self.watched_connection.as_ref() else {
278            return Empty.into_any();
279        };
280
281        let Some(message) = connection.messages.get(index) else {
282            return Empty.into_any();
283        };
284
285        let base_size = TextSize::Editor.rems(cx);
286
287        let theme_settings = ThemeSettings::get_global(cx);
288        let text_style = window.text_style();
289
290        let colors = cx.theme().colors();
291        let expanded = self.expanded.contains(&index);
292
293        v_flex()
294            .id(index)
295            .group("message")
296            .cursor_pointer()
297            .font_buffer(cx)
298            .w_full()
299            .py_3()
300            .pl_4()
301            .pr_5()
302            .gap_2()
303            .items_start()
304            .text_size(base_size)
305            .border_color(colors.border)
306            .border_b_1()
307            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
308            .on_click(cx.listener(move |this, _, _, cx| {
309                if this.expanded.contains(&index) {
310                    this.expanded.remove(&index);
311                } else {
312                    this.expanded.insert(index);
313                    let Some(connection) = &mut this.watched_connection else {
314                        return;
315                    };
316                    let Some(message) = connection.messages.get_mut(index) else {
317                        return;
318                    };
319                    message.expanded(this.project.read(cx).languages().clone(), cx);
320                    connection.list_state.scroll_to_reveal_item(index);
321                }
322                cx.notify()
323            }))
324            .child(
325                h_flex()
326                    .w_full()
327                    .gap_2()
328                    .flex_shrink_0()
329                    .child(match message.direction {
330                        acp::StreamMessageDirection::Incoming => Icon::new(IconName::ArrowDown)
331                            .color(Color::Error)
332                            .size(IconSize::Small),
333                        acp::StreamMessageDirection::Outgoing => Icon::new(IconName::ArrowUp)
334                            .color(Color::Success)
335                            .size(IconSize::Small),
336                    })
337                    .child(
338                        Label::new(message.name.clone())
339                            .buffer_font(cx)
340                            .color(Color::Muted),
341                    )
342                    .child(div().flex_1())
343                    .child(
344                        div()
345                            .child(ui::Chip::new(message.message_type.to_string()))
346                            .visible_on_hover("message"),
347                    )
348                    .children(
349                        message
350                            .request_id
351                            .as_ref()
352                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
353                    ),
354            )
355            // I'm aware using markdown is a hack. Trying to get something working for the demo.
356            // Will clean up soon!
357            .when_some(
358                if expanded {
359                    message.expanded_params_md.clone()
360                } else {
361                    message.collapsed_params_md.clone()
362                },
363                |this, params| {
364                    this.child(
365                        div().pl_6().w_full().child(
366                            MarkdownElement::new(
367                                params,
368                                MarkdownStyle {
369                                    base_text_style: text_style,
370                                    selection_background_color: colors.element_selection_background,
371                                    syntax: cx.theme().syntax().clone(),
372                                    code_block_overflow_x_scroll: true,
373                                    code_block: StyleRefinement {
374                                        text: Some(TextStyleRefinement {
375                                            font_family: Some(
376                                                theme_settings.buffer_font.family.clone(),
377                                            ),
378                                            font_size: Some((base_size * 0.8).into()),
379                                            ..Default::default()
380                                        }),
381                                        ..Default::default()
382                                    },
383                                    ..Default::default()
384                                },
385                            )
386                            .code_block_renderer(
387                                CodeBlockRenderer::Default {
388                                    copy_button: false,
389                                    copy_button_on_hover: expanded,
390                                    border: false,
391                                },
392                            ),
393                        ),
394                    )
395                },
396            )
397            .into_any()
398    }
399}
400
401struct WatchedConnectionMessage {
402    name: SharedString,
403    request_id: Option<acp::RequestId>,
404    direction: acp::StreamMessageDirection,
405    message_type: MessageType,
406    params: Result<Option<serde_json::Value>, acp::Error>,
407    collapsed_params_md: Option<Entity<Markdown>>,
408    expanded_params_md: Option<Entity<Markdown>>,
409}
410
411impl WatchedConnectionMessage {
412    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
413        let params_md = match &self.params {
414            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
415            Err(err) => {
416                if let Some(err) = &serde_json::to_value(err).log_err() {
417                    Some(expanded_params_md(&err, &language_registry, cx))
418                } else {
419                    None
420                }
421            }
422            _ => None,
423        };
424        self.expanded_params_md = params_md;
425    }
426}
427
428fn collapsed_params_md(
429    params: &serde_json::Value,
430    language_registry: &Arc<LanguageRegistry>,
431    cx: &mut App,
432) -> Entity<Markdown> {
433    let params_json = serde_json::to_string(params).unwrap_or_default();
434    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
435
436    for ch in params_json.chars() {
437        match ch {
438            '{' => spaced_out_json.push_str("{ "),
439            '}' => spaced_out_json.push_str(" }"),
440            ':' => spaced_out_json.push_str(": "),
441            ',' => spaced_out_json.push_str(", "),
442            c => spaced_out_json.push(c),
443        }
444    }
445
446    let params_md = format!("```json\n{}\n```", spaced_out_json);
447    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
448}
449
450fn expanded_params_md(
451    params: &serde_json::Value,
452    language_registry: &Arc<LanguageRegistry>,
453    cx: &mut App,
454) -> Entity<Markdown> {
455    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
456    let params_md = format!("```json\n{}\n```", params_json);
457    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
458}
459
/// The kind of a recorded message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    /// Writes the variant name (`Request`, `Response` or `Notification`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            MessageType::Request => "Request",
            MessageType::Response => "Response",
            MessageType::Notification => "Notification",
        })
    }
}
475
476enum AcpToolsEvent {}
477
478impl EventEmitter<AcpToolsEvent> for AcpTools {}
479
480impl Item for AcpTools {
481    type Event = AcpToolsEvent;
482
483    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
484        format!(
485            "ACP: {}",
486            self.watched_connection
487                .as_ref()
488                .map_or("Disconnected", |connection| &connection.server_name)
489        )
490        .into()
491    }
492
493    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
494        Some(ui::Icon::new(IconName::Thread))
495    }
496}
497
498impl Focusable for AcpTools {
499    fn focus_handle(&self, _cx: &App) -> FocusHandle {
500        self.focus_handle.clone()
501    }
502}
503
504impl Render for AcpTools {
505    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
506        v_flex()
507            .track_focus(&self.focus_handle)
508            .size_full()
509            .bg(cx.theme().colors().editor_background)
510            .child(match self.watched_connection.as_ref() {
511                Some(connection) => {
512                    if connection.messages.is_empty() {
513                        h_flex()
514                            .size_full()
515                            .justify_center()
516                            .items_center()
517                            .child("No messages recorded yet")
518                            .into_any()
519                    } else {
520                        div()
521                            .size_full()
522                            .flex_grow()
523                            .child(
524                                list(
525                                    connection.list_state.clone(),
526                                    cx.processor(Self::render_message),
527                                )
528                                .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
529                                .size_full(),
530                            )
531                            .vertical_scrollbar_for(connection.list_state.clone(), window, cx)
532                            .into_any()
533                    }
534                }
535                None => h_flex()
536                    .size_full()
537                    .justify_center()
538                    .items_center()
539                    .child("No active connection")
540                    .into_any(),
541            })
542    }
543}
544
545pub struct AcpToolsToolbarItemView {
546    acp_tools: Option<Entity<AcpTools>>,
547    just_copied: bool,
548}
549
550impl AcpToolsToolbarItemView {
551    pub fn new() -> Self {
552        Self {
553            acp_tools: None,
554            just_copied: false,
555        }
556    }
557}
558
559impl Render for AcpToolsToolbarItemView {
560    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
561        let Some(acp_tools) = self.acp_tools.as_ref() else {
562            return Empty.into_any_element();
563        };
564
565        let acp_tools = acp_tools.clone();
566        let has_messages = acp_tools
567            .read(cx)
568            .watched_connection
569            .as_ref()
570            .is_some_and(|connection| !connection.messages.is_empty());
571
572        h_flex()
573            .gap_2()
574            .child({
575                let acp_tools = acp_tools.clone();
576                IconButton::new(
577                    "copy_all_messages",
578                    if self.just_copied {
579                        IconName::Check
580                    } else {
581                        IconName::Copy
582                    },
583                )
584                .icon_size(IconSize::Small)
585                .tooltip(Tooltip::text(if self.just_copied {
586                    "Copied!"
587                } else {
588                    "Copy All Messages"
589                }))
590                .disabled(!has_messages)
591                .on_click(cx.listener(move |this, _, _window, cx| {
592                    if let Some(content) = acp_tools.read(cx).serialize_observed_messages() {
593                        cx.write_to_clipboard(ClipboardItem::new_string(content));
594
595                        this.just_copied = true;
596                        cx.spawn(async move |this, cx| {
597                            cx.background_executor().timer(Duration::from_secs(2)).await;
598                            this.update(cx, |this, cx| {
599                                this.just_copied = false;
600                                cx.notify();
601                            })
602                        })
603                        .detach();
604                    }
605                }))
606            })
607            .child(
608                IconButton::new("clear_messages", IconName::Trash)
609                    .icon_size(IconSize::Small)
610                    .tooltip(Tooltip::text("Clear Messages"))
611                    .disabled(!has_messages)
612                    .on_click(cx.listener(move |_this, _, _window, cx| {
613                        acp_tools.update(cx, |acp_tools, cx| {
614                            acp_tools.clear_messages(cx);
615                        });
616                    })),
617            )
618            .into_any()
619    }
620}
621
622impl EventEmitter<ToolbarItemEvent> for AcpToolsToolbarItemView {}
623
624impl ToolbarItemView for AcpToolsToolbarItemView {
625    fn set_active_pane_item(
626        &mut self,
627        active_pane_item: Option<&dyn ItemHandle>,
628        _window: &mut Window,
629        cx: &mut Context<Self>,
630    ) -> ToolbarItemLocation {
631        if let Some(item) = active_pane_item
632            && let Some(acp_tools) = item.downcast::<AcpTools>()
633        {
634            self.acp_tools = Some(acp_tools);
635            cx.notify();
636            return ToolbarItemLocation::PrimaryRight;
637        }
638        if self.acp_tools.take().is_some() {
639            cx.notify();
640        }
641        ToolbarItemLocation::Hidden
642    }
643}