acp_tools.rs

  1use std::{
  2    cell::RefCell,
  3    collections::HashSet,
  4    fmt::Display,
  5    rc::{Rc, Weak},
  6    sync::Arc,
  7};
  8
  9use agent_client_protocol as acp;
 10use collections::HashMap;
 11use gpui::{
 12    App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
 13    StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
 14};
 15use language::LanguageRegistry;
 16use markdown::{CodeBlockRenderer, CopyButtonVisibility, Markdown, MarkdownElement, MarkdownStyle};
 17use project::{AgentId, Project};
 18use settings::Settings;
 19use theme_settings::ThemeSettings;
 20use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
 21use util::ResultExt as _;
 22use workspace::{
 23    Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
 24};
 25
 26actions!(dev, [OpenAcpLogs]);
 27
 28pub fn init(cx: &mut App) {
 29    cx.observe_new(
 30        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 31            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
 32                let acp_tools =
 33                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
 34                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
 35            });
 36        },
 37    )
 38    .detach();
 39}
 40
/// App-global wrapper holding the shared [`AcpConnectionRegistry`] entity.
///
/// Installed on first use by `AcpConnectionRegistry::default_global`.
struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);

impl Global for GlobalAcpConnectionRegistry {}
 44
/// Tracks which ACP connection the log views should observe.
///
/// One instance is shared app-wide through `default_global`. `AcpTools` views
/// observe this entity and switch to whichever connection was most recently
/// passed to `set_active_connection`.
#[derive(Default)]
pub struct AcpConnectionRegistry {
    /// Most recently registered connection, if any. Wrapped in a `RefCell` so
    /// `set_active_connection` can replace it through `&self`.
    active_connection: RefCell<Option<ActiveConnection>>,
}
 49
/// A registered connection together with the agent it talks to.
struct ActiveConnection {
    /// Agent the connection belongs to. It is copied into the watched
    /// connection and shown in the log view's tab title.
    agent_id: AgentId,
    /// Weak handle, so registering a connection does not keep it alive.
    connection: Weak<acp::ClientSideConnection>,
}
 54
 55impl AcpConnectionRegistry {
 56    pub fn default_global(cx: &mut App) -> Entity<Self> {
 57        if cx.has_global::<GlobalAcpConnectionRegistry>() {
 58            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
 59        } else {
 60            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
 61            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
 62            registry
 63        }
 64    }
 65
 66    pub fn set_active_connection(
 67        &self,
 68        agent_id: AgentId,
 69        connection: &Rc<acp::ClientSideConnection>,
 70        cx: &mut Context<Self>,
 71    ) {
 72        self.active_connection.replace(Some(ActiveConnection {
 73            agent_id,
 74            connection: Rc::downgrade(connection),
 75        }));
 76        cx.notify();
 77    }
 78}
 79
/// Workspace item showing a live log of the messages exchanged over the
/// registry's active ACP connection.
struct AcpTools {
    /// Supplies the language registry used to highlight the JSON params.
    project: Entity<Project>,
    focus_handle: FocusHandle,
    /// Indices into the watched connection's `messages` of rows currently
    /// shown expanded.
    expanded: HashSet<usize>,
    /// Connection being logged; `None` until a live connection is observed.
    watched_connection: Option<WatchedConnection>,
    /// Global registry observed to learn about connection changes.
    connection_registry: Entity<AcpConnectionRegistry>,
    /// Held so the registry observation created in `new` stays registered.
    _subscription: Subscription,
}
 88
/// Log state for one observed connection.
///
/// Replaced wholesale when the active connection changes.
struct WatchedConnection {
    /// Agent on the other end; displayed in the tab title.
    agent_id: AgentId,
    /// Recorded messages in arrival order. List row `i` renders `messages[i]`.
    messages: Vec<WatchedConnectionMessage>,
    list_state: ListState,
    /// Weak handle compared against the registry's to detect connection changes.
    connection: Weak<acp::ClientSideConnection>,
    /// Method names of requests received from the peer, keyed by request id.
    /// Used to label the responses we send back.
    incoming_request_methods: HashMap<acp::RequestId, Arc<str>>,
    /// Method names of requests we sent, keyed by request id. Used to label
    /// the responses the peer sends back.
    outgoing_request_methods: HashMap<acp::RequestId, Arc<str>>,
    /// Task forwarding the connection's stream messages into the view. Held so
    /// it lives as long as this watch.
    _task: Task<()>,
}
 98
 99impl AcpTools {
100    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
101        let connection_registry = AcpConnectionRegistry::default_global(cx);
102
103        let subscription = cx.observe(&connection_registry, |this, _, cx| {
104            this.update_connection(cx);
105            cx.notify();
106        });
107
108        let mut this = Self {
109            project,
110            focus_handle: cx.focus_handle(),
111            expanded: HashSet::default(),
112            watched_connection: None,
113            connection_registry,
114            _subscription: subscription,
115        };
116        this.update_connection(cx);
117        this
118    }
119
120    fn update_connection(&mut self, cx: &mut Context<Self>) {
121        let active_connection = self.connection_registry.read(cx).active_connection.borrow();
122        let Some(active_connection) = active_connection.as_ref() else {
123            return;
124        };
125
126        if let Some(watched_connection) = self.watched_connection.as_ref() {
127            if Weak::ptr_eq(
128                &watched_connection.connection,
129                &active_connection.connection,
130            ) {
131                return;
132            }
133        }
134
135        if let Some(connection) = active_connection.connection.upgrade() {
136            let mut receiver = connection.subscribe();
137            let task = cx.spawn(async move |this, cx| {
138                while let Ok(message) = receiver.recv().await {
139                    this.update(cx, |this, cx| {
140                        this.push_stream_message(message, cx);
141                    })
142                    .ok();
143                }
144            });
145
146            self.watched_connection = Some(WatchedConnection {
147                agent_id: active_connection.agent_id.clone(),
148                messages: vec![],
149                list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
150                connection: active_connection.connection.clone(),
151                incoming_request_methods: HashMap::default(),
152                outgoing_request_methods: HashMap::default(),
153                _task: task,
154            });
155        }
156    }
157
158    fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
159        let Some(connection) = self.watched_connection.as_mut() else {
160            return;
161        };
162        let language_registry = self.project.read(cx).languages().clone();
163        let index = connection.messages.len();
164
165        let (request_id, method, message_type, params) = match stream_message.message {
166            acp::StreamMessageContent::Request { id, method, params } => {
167                let method_map = match stream_message.direction {
168                    acp::StreamMessageDirection::Incoming => {
169                        &mut connection.incoming_request_methods
170                    }
171                    acp::StreamMessageDirection::Outgoing => {
172                        &mut connection.outgoing_request_methods
173                    }
174                };
175
176                method_map.insert(id.clone(), method.clone());
177                (Some(id), method.into(), MessageType::Request, Ok(params))
178            }
179            acp::StreamMessageContent::Response { id, result } => {
180                let method_map = match stream_message.direction {
181                    acp::StreamMessageDirection::Incoming => {
182                        &mut connection.outgoing_request_methods
183                    }
184                    acp::StreamMessageDirection::Outgoing => {
185                        &mut connection.incoming_request_methods
186                    }
187                };
188
189                if let Some(method) = method_map.remove(&id) {
190                    (Some(id), method.into(), MessageType::Response, result)
191                } else {
192                    (
193                        Some(id),
194                        "[unrecognized response]".into(),
195                        MessageType::Response,
196                        result,
197                    )
198                }
199            }
200            acp::StreamMessageContent::Notification { method, params } => {
201                (None, method.into(), MessageType::Notification, Ok(params))
202            }
203        };
204
205        let message = WatchedConnectionMessage {
206            name: method,
207            message_type,
208            request_id,
209            direction: stream_message.direction,
210            collapsed_params_md: match params.as_ref() {
211                Ok(params) => params
212                    .as_ref()
213                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
214                Err(err) => {
215                    if let Ok(err) = &serde_json::to_value(err) {
216                        Some(collapsed_params_md(&err, &language_registry, cx))
217                    } else {
218                        None
219                    }
220                }
221            },
222
223            expanded_params_md: None,
224            params,
225        };
226
227        connection.messages.push(message);
228        connection.list_state.splice(index..index, 1);
229        cx.notify();
230    }
231
232    fn serialize_observed_messages(&self) -> Option<String> {
233        let connection = self.watched_connection.as_ref()?;
234
235        let messages: Vec<serde_json::Value> = connection
236            .messages
237            .iter()
238            .filter_map(|message| {
239                let params = match &message.params {
240                    Ok(Some(params)) => params.clone(),
241                    Ok(None) => serde_json::Value::Null,
242                    Err(err) => serde_json::to_value(err).ok()?,
243                };
244                Some(serde_json::json!({
245                    "_direction": match message.direction {
246                        acp::StreamMessageDirection::Incoming => "incoming",
247                        acp::StreamMessageDirection::Outgoing => "outgoing",
248                    },
249                    "_type": message.message_type.to_string().to_lowercase(),
250                    "id": message.request_id,
251                    "method": message.name.to_string(),
252                    "params": params,
253                }))
254            })
255            .collect();
256
257        serde_json::to_string_pretty(&messages).ok()
258    }
259
260    fn clear_messages(&mut self, cx: &mut Context<Self>) {
261        if let Some(connection) = self.watched_connection.as_mut() {
262            connection.messages.clear();
263            connection.list_state.reset(0);
264            self.expanded.clear();
265            cx.notify();
266        }
267    }
268
269    fn render_message(
270        &mut self,
271        index: usize,
272        window: &mut Window,
273        cx: &mut Context<Self>,
274    ) -> AnyElement {
275        let Some(connection) = self.watched_connection.as_ref() else {
276            return Empty.into_any();
277        };
278
279        let Some(message) = connection.messages.get(index) else {
280            return Empty.into_any();
281        };
282
283        let base_size = TextSize::Editor.rems(cx);
284
285        let theme_settings = ThemeSettings::get_global(cx);
286        let text_style = window.text_style();
287
288        let colors = cx.theme().colors();
289        let expanded = self.expanded.contains(&index);
290
291        v_flex()
292            .id(index)
293            .group("message")
294            .font_buffer(cx)
295            .w_full()
296            .py_3()
297            .pl_4()
298            .pr_5()
299            .gap_2()
300            .items_start()
301            .text_size(base_size)
302            .border_color(colors.border)
303            .border_b_1()
304            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
305            .child(
306                h_flex()
307                    .id(("acp-log-message-header", index))
308                    .w_full()
309                    .gap_2()
310                    .flex_shrink_0()
311                    .cursor_pointer()
312                    .on_click(cx.listener(move |this, _, _, cx| {
313                        if this.expanded.contains(&index) {
314                            this.expanded.remove(&index);
315                        } else {
316                            this.expanded.insert(index);
317                            let Some(connection) = &mut this.watched_connection else {
318                                return;
319                            };
320                            let Some(message) = connection.messages.get_mut(index) else {
321                                return;
322                            };
323                            message.expanded(this.project.read(cx).languages().clone(), cx);
324                            connection.list_state.scroll_to_reveal_item(index);
325                        }
326                        cx.notify()
327                    }))
328                    .child(match message.direction {
329                        acp::StreamMessageDirection::Incoming => Icon::new(IconName::ArrowDown)
330                            .color(Color::Error)
331                            .size(IconSize::Small),
332                        acp::StreamMessageDirection::Outgoing => Icon::new(IconName::ArrowUp)
333                            .color(Color::Success)
334                            .size(IconSize::Small),
335                    })
336                    .child(
337                        Label::new(message.name.clone())
338                            .buffer_font(cx)
339                            .color(Color::Muted),
340                    )
341                    .child(div().flex_1())
342                    .child(
343                        div()
344                            .child(ui::Chip::new(message.message_type.to_string()))
345                            .visible_on_hover("message"),
346                    )
347                    .children(
348                        message
349                            .request_id
350                            .as_ref()
351                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
352                    ),
353            )
354            // I'm aware using markdown is a hack. Trying to get something working for the demo.
355            // Will clean up soon!
356            .when_some(
357                if expanded {
358                    message.expanded_params_md.clone()
359                } else {
360                    message.collapsed_params_md.clone()
361                },
362                |this, params| {
363                    this.child(
364                        div().pl_6().w_full().child(
365                            MarkdownElement::new(
366                                params,
367                                MarkdownStyle {
368                                    base_text_style: text_style,
369                                    selection_background_color: colors.element_selection_background,
370                                    syntax: cx.theme().syntax().clone(),
371                                    code_block_overflow_x_scroll: true,
372                                    code_block: StyleRefinement {
373                                        text: TextStyleRefinement {
374                                            font_family: Some(
375                                                theme_settings.buffer_font.family.clone(),
376                                            ),
377                                            font_size: Some((base_size * 0.8).into()),
378                                            ..Default::default()
379                                        },
380                                        ..Default::default()
381                                    },
382                                    ..Default::default()
383                                },
384                            )
385                            .code_block_renderer(
386                                CodeBlockRenderer::Default {
387                                    copy_button_visibility: if expanded {
388                                        CopyButtonVisibility::VisibleOnHover
389                                    } else {
390                                        CopyButtonVisibility::Hidden
391                                    },
392                                    border: false,
393                                },
394                            ),
395                        ),
396                    )
397                },
398            )
399            .into_any()
400    }
401}
402
/// One logged message plus its rendered Markdown forms.
struct WatchedConnectionMessage {
    /// Method name. Responses reuse their request's method, or
    /// `[unrecognized response]` when that request was not seen.
    name: SharedString,
    /// `None` for notifications.
    request_id: Option<acp::RequestId>,
    direction: acp::StreamMessageDirection,
    message_type: MessageType,
    /// Params of a request or notification, or the result of a response
    /// (which may be an error).
    params: Result<Option<serde_json::Value>, acp::Error>,
    /// Compact one-line rendering, built when the message is recorded.
    collapsed_params_md: Option<Entity<Markdown>>,
    /// Pretty-printed rendering, filled in by `expanded` when the row is opened.
    expanded_params_md: Option<Entity<Markdown>>,
}
412
413impl WatchedConnectionMessage {
414    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
415        let params_md = match &self.params {
416            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
417            Err(err) => {
418                if let Some(err) = &serde_json::to_value(err).log_err() {
419                    Some(expanded_params_md(&err, &language_registry, cx))
420                } else {
421                    None
422                }
423            }
424            _ => None,
425        };
426        self.expanded_params_md = params_md;
427    }
428}
429
430fn collapsed_params_md(
431    params: &serde_json::Value,
432    language_registry: &Arc<LanguageRegistry>,
433    cx: &mut App,
434) -> Entity<Markdown> {
435    let params_json = serde_json::to_string(params).unwrap_or_default();
436    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
437
438    for ch in params_json.chars() {
439        match ch {
440            '{' => spaced_out_json.push_str("{ "),
441            '}' => spaced_out_json.push_str(" }"),
442            ':' => spaced_out_json.push_str(": "),
443            ',' => spaced_out_json.push_str(", "),
444            c => spaced_out_json.push(c),
445        }
446    }
447
448    let params_md = format!("```json\n{}\n```", spaced_out_json);
449    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
450}
451
452fn expanded_params_md(
453    params: &serde_json::Value,
454    language_registry: &Arc<LanguageRegistry>,
455    cx: &mut App,
456) -> Entity<Markdown> {
457    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
458    let params_md = format!("```json\n{}\n```", params_json);
459    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
460}
461
/// Kind of message observed on an ACP connection.
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    /// Writes the capitalized kind name, e.g. `Request`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Request => "Request",
            Self::Response => "Response",
            Self::Notification => "Notification",
        };
        f.write_str(label)
    }
}
477
/// Event type for `AcpTools`, used as its `Item::Event`. It has no variants, so
/// no event can ever be constructed or emitted.
enum AcpToolsEvent {}

impl EventEmitter<AcpToolsEvent> for AcpTools {}
481
482impl Item for AcpTools {
483    type Event = AcpToolsEvent;
484
485    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
486        format!(
487            "ACP: {}",
488            self.watched_connection
489                .as_ref()
490                .map_or("Disconnected", |connection| connection.agent_id.0.as_ref())
491        )
492        .into()
493    }
494
495    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
496        Some(ui::Icon::new(IconName::Thread))
497    }
498}
499
impl Focusable for AcpTools {
    /// Returns the handle tracked by the root element built in `render`.
    fn focus_handle(&self, _cx: &App) -> FocusHandle {
        self.focus_handle.clone()
    }
}
505
506impl Render for AcpTools {
507    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
508        v_flex()
509            .track_focus(&self.focus_handle)
510            .size_full()
511            .bg(cx.theme().colors().editor_background)
512            .child(match self.watched_connection.as_ref() {
513                Some(connection) => {
514                    if connection.messages.is_empty() {
515                        h_flex()
516                            .size_full()
517                            .justify_center()
518                            .items_center()
519                            .child("No messages recorded yet")
520                            .into_any()
521                    } else {
522                        div()
523                            .size_full()
524                            .flex_grow()
525                            .child(
526                                list(
527                                    connection.list_state.clone(),
528                                    cx.processor(Self::render_message),
529                                )
530                                .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
531                                .size_full(),
532                            )
533                            .vertical_scrollbar_for(&connection.list_state, window, cx)
534                            .into_any()
535                    }
536                }
537                None => h_flex()
538                    .size_full()
539                    .justify_center()
540                    .items_center()
541                    .child("No active connection")
542                    .into_any(),
543            })
544    }
545}
546
/// Toolbar controls ("Copy All Messages" and "Clear Messages") shown while an
/// `AcpTools` item is the active pane item.
pub struct AcpToolsToolbarItemView {
    /// The active log view, or `None` when the active item is something else.
    acp_tools: Option<Entity<AcpTools>>,
}
550
551impl AcpToolsToolbarItemView {
552    pub fn new() -> Self {
553        Self { acp_tools: None }
554    }
555}
556
557impl Render for AcpToolsToolbarItemView {
558    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
559        let Some(acp_tools) = self.acp_tools.as_ref() else {
560            return Empty.into_any_element();
561        };
562
563        let acp_tools = acp_tools.clone();
564        let has_messages = acp_tools
565            .read(cx)
566            .watched_connection
567            .as_ref()
568            .is_some_and(|connection| !connection.messages.is_empty());
569
570        h_flex()
571            .gap_2()
572            .child({
573                let message = acp_tools
574                    .read(cx)
575                    .serialize_observed_messages()
576                    .unwrap_or_default();
577
578                CopyButton::new("copy-all-messages", message)
579                    .tooltip_label("Copy All Messages")
580                    .disabled(!has_messages)
581            })
582            .child(
583                IconButton::new("clear_messages", IconName::Trash)
584                    .icon_size(IconSize::Small)
585                    .tooltip(Tooltip::text("Clear Messages"))
586                    .disabled(!has_messages)
587                    .on_click(cx.listener(move |_this, _, _window, cx| {
588                        acp_tools.update(cx, |acp_tools, cx| {
589                            acp_tools.clear_messages(cx);
590                        });
591                    })),
592            )
593            .into_any()
594    }
595}
596
// NOTE(review): presumably required because `ToolbarItemView` needs
// `EventEmitter<ToolbarItemEvent>`; nothing in this file emits such events.
impl EventEmitter<ToolbarItemEvent> for AcpToolsToolbarItemView {}
598
599impl ToolbarItemView for AcpToolsToolbarItemView {
600    fn set_active_pane_item(
601        &mut self,
602        active_pane_item: Option<&dyn ItemHandle>,
603        _window: &mut Window,
604        cx: &mut Context<Self>,
605    ) -> ToolbarItemLocation {
606        if let Some(item) = active_pane_item
607            && let Some(acp_tools) = item.downcast::<AcpTools>()
608        {
609            self.acp_tools = Some(acp_tools);
610            cx.notify();
611            return ToolbarItemLocation::PrimaryRight;
612        }
613        if self.acp_tools.take().is_some() {
614            cx.notify();
615        }
616        ToolbarItemLocation::Hidden
617    }
618}