acp_tools.rs

  1use std::{collections::HashSet, fmt::Display, sync::Arc};
  2
  3use agent_client_protocol::schema as acp;
  4use collections::HashMap;
  5use gpui::{
  6    App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
  7    StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
  8};
  9use language::LanguageRegistry;
 10use markdown::{CodeBlockRenderer, CopyButtonVisibility, Markdown, MarkdownElement, MarkdownStyle};
 11use project::{AgentId, Project};
 12use settings::Settings;
 13use theme_settings::ThemeSettings;
 14use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
 15use util::ResultExt as _;
 16use workspace::{
 17    Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
 18};
 19
/// JSON-RPC request id as it appeared on the wire, kept as a raw JSON value.
/// It is used as the key when matching responses back to their requests.
pub type RequestId = serde_json::Value;
 21
 22#[derive(Clone)]
 23pub enum StreamMessageDirection {
 24    Incoming,
 25    Outgoing,
 26}
 27
/// Payload of a JSON-RPC message, as classified by [`StreamMessage::from_json_line`].
#[derive(Clone)]
pub enum StreamMessageContent {
    /// A call carrying both `method` and `id`.
    Request {
        id: RequestId,
        method: Arc<str>,
        params: Option<serde_json::Value>,
    },
    /// A reply (an `id` without a `method`): `Ok` holds the `result` member,
    /// `Err` holds the decoded `error` member.
    Response {
        id: RequestId,
        result: Result<Option<serde_json::Value>, acp::Error>,
    },
    /// A one-way message carrying `method` but no `id`.
    Notification {
        method: Arc<str>,
        params: Option<serde_json::Value>,
    },
}
 44
/// A parsed JSON-RPC message together with the direction it traveled.
#[derive(Clone)]
pub struct StreamMessage {
    pub direction: StreamMessageDirection,
    pub message: StreamMessageContent,
}
 50
 51impl StreamMessage {
 52    pub fn from_json_line(direction: StreamMessageDirection, line: &str) -> Option<Self> {
 53        let value: serde_json::Value = serde_json::from_str(line).ok()?;
 54        let obj = value.as_object()?;
 55
 56        let message = if let Some(method) = obj.get("method").and_then(|m| m.as_str()) {
 57            if let Some(id) = obj.get("id") {
 58                StreamMessageContent::Request {
 59                    id: id.clone(),
 60                    method: method.into(),
 61                    params: obj.get("params").cloned(),
 62                }
 63            } else {
 64                StreamMessageContent::Notification {
 65                    method: method.into(),
 66                    params: obj.get("params").cloned(),
 67                }
 68            }
 69        } else if let Some(id) = obj.get("id") {
 70            if let Some(error) = obj.get("error") {
 71                let acp_err =
 72                    serde_json::from_value::<acp::Error>(error.clone()).unwrap_or_else(|err| {
 73                        log::warn!("Failed to deserialize ACP error: {err}");
 74                        acp::Error::internal_error().data(error.to_string())
 75                    });
 76                StreamMessageContent::Response {
 77                    id: id.clone(),
 78                    result: Err(acp_err),
 79                }
 80            } else {
 81                StreamMessageContent::Response {
 82                    id: id.clone(),
 83                    result: Ok(obj.get("result").cloned()),
 84                }
 85            }
 86        } else {
 87            return None;
 88        };
 89
 90        Some(StreamMessage { direction, message })
 91    }
 92}
 93
 94actions!(dev, [OpenAcpLogs]);
 95
 96pub fn init(cx: &mut App) {
 97    cx.observe_new(
 98        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 99            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
100                let acp_tools =
101                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
102                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
103            });
104        },
105    )
106    .detach();
107}
108
/// Global slot holding the app-wide `AcpConnectionRegistry` entity; installed
/// lazily by `AcpConnectionRegistry::default_global`.
struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);

impl Global for GlobalAcpConnectionRegistry {}
112
/// A raw JSON-RPC line captured from the transport, tagged with direction.
/// Deserialization into [`StreamMessage`] is deferred until a subscriber is listening.
pub struct RawStreamLine {
    /// Incoming vs. outgoing, relative to this side of the transport.
    pub direction: StreamMessageDirection,
    /// The unparsed JSON text of a single message.
    pub line: Arc<str>,
}
119
/// App-wide hub that fans out the active ACP connection's traffic to
/// subscribers such as the ACP logs tab.
#[derive(Default)]
pub struct AcpConnectionRegistry {
    /// Agent of the currently active connection; `None` when disconnected.
    active_agent_id: Option<AgentId>,
    /// Incremented on every `set_active_connection` so watchers can detect a new connection.
    generation: u64,
    /// One sender per subscriber; cleared when a connection starts or its transport closes.
    subscribers: Vec<smol::channel::Sender<StreamMessage>>,
    /// Task forwarding lines from the active transport; replaced per connection.
    _broadcast_task: Option<Task<()>>,
}
127
128impl AcpConnectionRegistry {
129    pub fn default_global(cx: &mut App) -> Entity<Self> {
130        if cx.has_global::<GlobalAcpConnectionRegistry>() {
131            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
132        } else {
133            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
134            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
135            registry
136        }
137    }
138
139    pub fn set_active_connection(
140        &mut self,
141        agent_id: AgentId,
142        raw_rx: smol::channel::Receiver<RawStreamLine>,
143        cx: &mut Context<Self>,
144    ) {
145        self.active_agent_id = Some(agent_id);
146        self.generation += 1;
147        self.subscribers.clear();
148
149        self._broadcast_task = Some(cx.spawn(async move |this, cx| {
150            while let Ok(raw) = raw_rx.recv().await {
151                this.update(cx, |this, _cx| {
152                    if this.subscribers.is_empty() {
153                        return;
154                    }
155
156                    let Some(message) = StreamMessage::from_json_line(raw.direction, &raw.line)
157                    else {
158                        return;
159                    };
160
161                    this.subscribers.retain(|sender| !sender.is_closed());
162                    for sender in &this.subscribers {
163                        sender.try_send(message.clone()).log_err();
164                    }
165                })
166                .log_err();
167            }
168
169            // The transport closed — clear state so observers (e.g. the ACP
170            // logs tab) can transition back to the disconnected state.
171            this.update(cx, |this, cx| {
172                this.active_agent_id = None;
173                this.subscribers.clear();
174                cx.notify();
175            })
176            .log_err();
177        }));
178
179        cx.notify();
180    }
181
182    pub fn subscribe(&mut self) -> smol::channel::Receiver<StreamMessage> {
183        let (sender, receiver) = smol::channel::unbounded();
184        self.subscribers.push(sender);
185        receiver
186    }
187}
188
/// Workspace item (the "ACP: …" tab) that shows the JSON-RPC traffic of the
/// active ACP connection.
struct AcpTools {
    /// Provides the language registry used to highlight JSON params.
    project: Entity<Project>,
    focus_handle: FocusHandle,
    /// Indices into the watched connection's `messages` that are currently expanded.
    expanded: HashSet<usize>,
    /// The connection being logged; `None` while the registry has no active agent.
    watched_connection: Option<WatchedConnection>,
    connection_registry: Entity<AcpConnectionRegistry>,
    /// Keeps the observation of `connection_registry` alive.
    _subscription: Subscription,
}
197
/// Log state for one registry connection (identified by `generation`).
struct WatchedConnection {
    agent_id: AgentId,
    /// The registry generation this log was subscribed to.
    generation: u64,
    /// Recorded messages in arrival order.
    messages: Vec<WatchedConnectionMessage>,
    /// List state kept at one item per entry in `messages`.
    list_state: ListState,
    /// Methods of requests received from the peer, keyed by id; used to label
    /// the outgoing responses to them.
    incoming_request_methods: HashMap<RequestId, Arc<str>>,
    /// Methods of requests sent to the peer, keyed by id; used to label the
    /// incoming responses to them.
    outgoing_request_methods: HashMap<RequestId, Arc<str>>,
    /// Task draining the registry subscription into `messages`.
    _task: Task<()>,
}
207
208impl AcpTools {
209    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
210        let connection_registry = AcpConnectionRegistry::default_global(cx);
211
212        let subscription = cx.observe(&connection_registry, |this, _, cx| {
213            this.update_connection(cx);
214            cx.notify();
215        });
216
217        let mut this = Self {
218            project,
219            focus_handle: cx.focus_handle(),
220            expanded: HashSet::default(),
221            watched_connection: None,
222            connection_registry,
223            _subscription: subscription,
224        };
225        this.update_connection(cx);
226        this
227    }
228
229    fn update_connection(&mut self, cx: &mut Context<Self>) {
230        let (generation, agent_id) = {
231            let registry = self.connection_registry.read(cx);
232            (registry.generation, registry.active_agent_id.clone())
233        };
234
235        let Some(agent_id) = agent_id else {
236            self.watched_connection = None;
237            return;
238        };
239
240        if let Some(watched) = self.watched_connection.as_ref() {
241            if watched.generation == generation {
242                return;
243            }
244        }
245
246        let messages_rx = self
247            .connection_registry
248            .update(cx, |registry, _cx| registry.subscribe());
249
250        let task = cx.spawn(async move |this, cx| {
251            while let Ok(message) = messages_rx.recv().await {
252                this.update(cx, |this, cx| {
253                    this.push_stream_message(message, cx);
254                })
255                .log_err();
256            }
257        });
258
259        self.watched_connection = Some(WatchedConnection {
260            agent_id,
261            generation,
262            messages: vec![],
263            list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
264            incoming_request_methods: HashMap::default(),
265            outgoing_request_methods: HashMap::default(),
266            _task: task,
267        });
268    }
269
270    fn push_stream_message(&mut self, stream_message: StreamMessage, cx: &mut Context<Self>) {
271        let Some(connection) = self.watched_connection.as_mut() else {
272            return;
273        };
274        let language_registry = self.project.read(cx).languages().clone();
275        let index = connection.messages.len();
276
277        let (request_id, method, message_type, params) = match stream_message.message {
278            StreamMessageContent::Request { id, method, params } => {
279                let method_map = match stream_message.direction {
280                    StreamMessageDirection::Incoming => &mut connection.incoming_request_methods,
281                    StreamMessageDirection::Outgoing => &mut connection.outgoing_request_methods,
282                };
283
284                method_map.insert(id.clone(), method.clone());
285                (Some(id), method.into(), MessageType::Request, Ok(params))
286            }
287            StreamMessageContent::Response { id, result } => {
288                let method_map = match stream_message.direction {
289                    StreamMessageDirection::Incoming => &mut connection.outgoing_request_methods,
290                    StreamMessageDirection::Outgoing => &mut connection.incoming_request_methods,
291                };
292
293                if let Some(method) = method_map.remove(&id) {
294                    (Some(id), method.into(), MessageType::Response, result)
295                } else {
296                    (
297                        Some(id),
298                        "[unrecognized response]".into(),
299                        MessageType::Response,
300                        result,
301                    )
302                }
303            }
304            StreamMessageContent::Notification { method, params } => {
305                (None, method.into(), MessageType::Notification, Ok(params))
306            }
307        };
308
309        let message = WatchedConnectionMessage {
310            name: method,
311            message_type,
312            request_id,
313            direction: stream_message.direction,
314            collapsed_params_md: match params.as_ref() {
315                Ok(params) => params
316                    .as_ref()
317                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
318                Err(err) => {
319                    if let Ok(err) = &serde_json::to_value(err) {
320                        Some(collapsed_params_md(&err, &language_registry, cx))
321                    } else {
322                        None
323                    }
324                }
325            },
326
327            expanded_params_md: None,
328            params,
329        };
330
331        connection.messages.push(message);
332        connection.list_state.splice(index..index, 1);
333        cx.notify();
334    }
335
336    fn serialize_observed_messages(&self) -> Option<String> {
337        let connection = self.watched_connection.as_ref()?;
338
339        let messages: Vec<serde_json::Value> = connection
340            .messages
341            .iter()
342            .filter_map(|message| {
343                let params = match &message.params {
344                    Ok(Some(params)) => params.clone(),
345                    Ok(None) => serde_json::Value::Null,
346                    Err(err) => serde_json::to_value(err).ok()?,
347                };
348                Some(serde_json::json!({
349                    "_direction": match message.direction {
350                        StreamMessageDirection::Incoming => "incoming",
351                        StreamMessageDirection::Outgoing => "outgoing",
352                    },
353                    "_type": message.message_type.to_string().to_lowercase(),
354                    "id": message.request_id,
355                    "method": message.name.to_string(),
356                    "params": params,
357                }))
358            })
359            .collect();
360
361        serde_json::to_string_pretty(&messages).ok()
362    }
363
364    fn clear_messages(&mut self, cx: &mut Context<Self>) {
365        if let Some(connection) = self.watched_connection.as_mut() {
366            connection.messages.clear();
367            connection.list_state.reset(0);
368            self.expanded.clear();
369            cx.notify();
370        }
371    }
372
373    fn render_message(
374        &mut self,
375        index: usize,
376        window: &mut Window,
377        cx: &mut Context<Self>,
378    ) -> AnyElement {
379        let Some(connection) = self.watched_connection.as_ref() else {
380            return Empty.into_any();
381        };
382
383        let Some(message) = connection.messages.get(index) else {
384            return Empty.into_any();
385        };
386
387        let base_size = TextSize::Editor.rems(cx);
388
389        let theme_settings = ThemeSettings::get_global(cx);
390        let text_style = window.text_style();
391
392        let colors = cx.theme().colors();
393        let expanded = self.expanded.contains(&index);
394
395        v_flex()
396            .id(index)
397            .group("message")
398            .font_buffer(cx)
399            .w_full()
400            .py_3()
401            .pl_4()
402            .pr_5()
403            .gap_2()
404            .items_start()
405            .text_size(base_size)
406            .border_color(colors.border)
407            .border_b_1()
408            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
409            .child(
410                h_flex()
411                    .id(("acp-log-message-header", index))
412                    .w_full()
413                    .gap_2()
414                    .flex_shrink_0()
415                    .cursor_pointer()
416                    .on_click(cx.listener(move |this, _, _, cx| {
417                        if this.expanded.contains(&index) {
418                            this.expanded.remove(&index);
419                        } else {
420                            this.expanded.insert(index);
421                            let Some(connection) = &mut this.watched_connection else {
422                                return;
423                            };
424                            let Some(message) = connection.messages.get_mut(index) else {
425                                return;
426                            };
427                            message.expanded(this.project.read(cx).languages().clone(), cx);
428                            connection.list_state.scroll_to_reveal_item(index);
429                        }
430                        cx.notify()
431                    }))
432                    .child(match message.direction {
433                        StreamMessageDirection::Incoming => Icon::new(IconName::ArrowDown)
434                            .color(Color::Error)
435                            .size(IconSize::Small),
436                        StreamMessageDirection::Outgoing => Icon::new(IconName::ArrowUp)
437                            .color(Color::Success)
438                            .size(IconSize::Small),
439                    })
440                    .child(
441                        Label::new(message.name.clone())
442                            .buffer_font(cx)
443                            .color(Color::Muted),
444                    )
445                    .child(div().flex_1())
446                    .child(
447                        div()
448                            .child(ui::Chip::new(message.message_type.to_string()))
449                            .visible_on_hover("message"),
450                    )
451                    .children(
452                        message
453                            .request_id
454                            .as_ref()
455                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
456                    ),
457            )
458            // I'm aware using markdown is a hack. Trying to get something working for the demo.
459            // Will clean up soon!
460            .when_some(
461                if expanded {
462                    message.expanded_params_md.clone()
463                } else {
464                    message.collapsed_params_md.clone()
465                },
466                |this, params| {
467                    this.child(
468                        div().pl_6().w_full().child(
469                            MarkdownElement::new(
470                                params,
471                                MarkdownStyle {
472                                    base_text_style: text_style,
473                                    selection_background_color: colors.element_selection_background,
474                                    syntax: cx.theme().syntax().clone(),
475                                    code_block_overflow_x_scroll: true,
476                                    code_block: StyleRefinement {
477                                        text: TextStyleRefinement {
478                                            font_family: Some(
479                                                theme_settings.buffer_font.family.clone(),
480                                            ),
481                                            font_size: Some((base_size * 0.8).into()),
482                                            ..Default::default()
483                                        },
484                                        ..Default::default()
485                                    },
486                                    ..Default::default()
487                                },
488                            )
489                            .code_block_renderer(
490                                CodeBlockRenderer::Default {
491                                    copy_button_visibility: if expanded {
492                                        CopyButtonVisibility::VisibleOnHover
493                                    } else {
494                                        CopyButtonVisibility::Hidden
495                                    },
496                                    border: false,
497                                },
498                            ),
499                        ),
500                    )
501                },
502            )
503            .into_any()
504    }
505}
506
/// One row of the ACP log.
struct WatchedConnectionMessage {
    /// The method name; for responses, the method of the matched request or
    /// `[unrecognized response]`.
    name: SharedString,
    /// `None` for notifications.
    request_id: Option<RequestId>,
    direction: StreamMessageDirection,
    message_type: MessageType,
    /// Params of requests/notifications, or the result/error of a response.
    params: Result<Option<serde_json::Value>, acp::Error>,
    /// One-line preview; `None` when there is nothing to show.
    collapsed_params_md: Option<Entity<Markdown>>,
    /// Pretty-printed rendering; `None` until the row is expanded.
    expanded_params_md: Option<Entity<Markdown>>,
}
516
517impl WatchedConnectionMessage {
518    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
519        let params_md = match &self.params {
520            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
521            Err(err) => {
522                if let Some(err) = &serde_json::to_value(err).log_err() {
523                    Some(expanded_params_md(&err, &language_registry, cx))
524                } else {
525                    None
526                }
527            }
528            _ => None,
529        };
530        self.expanded_params_md = params_md;
531    }
532}
533
534fn collapsed_params_md(
535    params: &serde_json::Value,
536    language_registry: &Arc<LanguageRegistry>,
537    cx: &mut App,
538) -> Entity<Markdown> {
539    let params_json = serde_json::to_string(params).unwrap_or_default();
540    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
541
542    for ch in params_json.chars() {
543        match ch {
544            '{' => spaced_out_json.push_str("{ "),
545            '}' => spaced_out_json.push_str(" }"),
546            ':' => spaced_out_json.push_str(": "),
547            ',' => spaced_out_json.push_str(", "),
548            c => spaced_out_json.push(c),
549        }
550    }
551
552    let params_md = format!("```json\n{}\n```", spaced_out_json);
553    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
554}
555
556fn expanded_params_md(
557    params: &serde_json::Value,
558    language_registry: &Arc<LanguageRegistry>,
559    cx: &mut App,
560) -> Entity<Markdown> {
561    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
562    let params_md = format!("```json\n{}\n```", params_json);
563    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
564}
565
/// Kind of JSON-RPC message in the log. `Display` yields the chip label; its
/// lowercased form is the `_type` field of copied messages.
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Request => "Request",
            Self::Response => "Response",
            Self::Notification => "Notification",
        };
        f.write_str(label)
    }
}
581
/// Event type for the `Item` impl; it has no variants, so no events can be emitted.
enum AcpToolsEvent {}

impl EventEmitter<AcpToolsEvent> for AcpTools {}
585
586impl Item for AcpTools {
587    type Event = AcpToolsEvent;
588
589    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
590        format!(
591            "ACP: {}",
592            self.watched_connection
593                .as_ref()
594                .map_or("Disconnected", |connection| connection.agent_id.0.as_ref())
595        )
596        .into()
597    }
598
599    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
600        Some(ui::Icon::new(IconName::Thread))
601    }
602}
603
604impl Focusable for AcpTools {
605    fn focus_handle(&self, _cx: &App) -> FocusHandle {
606        self.focus_handle.clone()
607    }
608}
609
610impl Render for AcpTools {
611    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
612        v_flex()
613            .track_focus(&self.focus_handle)
614            .size_full()
615            .bg(cx.theme().colors().editor_background)
616            .child(match self.watched_connection.as_ref() {
617                Some(connection) => {
618                    if connection.messages.is_empty() {
619                        h_flex()
620                            .size_full()
621                            .justify_center()
622                            .items_center()
623                            .child("No messages recorded yet")
624                            .into_any()
625                    } else {
626                        div()
627                            .size_full()
628                            .flex_grow()
629                            .child(
630                                list(
631                                    connection.list_state.clone(),
632                                    cx.processor(Self::render_message),
633                                )
634                                .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
635                                .size_full(),
636                            )
637                            .vertical_scrollbar_for(&connection.list_state, window, cx)
638                            .into_any()
639                    }
640                }
641                None => h_flex()
642                    .size_full()
643                    .justify_center()
644                    .items_center()
645                    .child("No active connection")
646                    .into_any(),
647            })
648    }
649}
650
651pub struct AcpToolsToolbarItemView {
652    acp_tools: Option<Entity<AcpTools>>,
653}
654
655impl AcpToolsToolbarItemView {
656    pub fn new() -> Self {
657        Self { acp_tools: None }
658    }
659}
660
661impl Render for AcpToolsToolbarItemView {
662    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
663        let Some(acp_tools) = self.acp_tools.as_ref() else {
664            return Empty.into_any_element();
665        };
666
667        let acp_tools = acp_tools.clone();
668        let has_messages = acp_tools
669            .read(cx)
670            .watched_connection
671            .as_ref()
672            .is_some_and(|connection| !connection.messages.is_empty());
673
674        h_flex()
675            .gap_2()
676            .child({
677                let message = acp_tools
678                    .read(cx)
679                    .serialize_observed_messages()
680                    .unwrap_or_default();
681
682                CopyButton::new("copy-all-messages", message)
683                    .tooltip_label("Copy All Messages")
684                    .disabled(!has_messages)
685            })
686            .child(
687                IconButton::new("clear_messages", IconName::Trash)
688                    .icon_size(IconSize::Small)
689                    .tooltip(Tooltip::text("Clear Messages"))
690                    .disabled(!has_messages)
691                    .on_click(cx.listener(move |_this, _, _window, cx| {
692                        acp_tools.update(cx, |acp_tools, cx| {
693                            acp_tools.clear_messages(cx);
694                        });
695                    })),
696            )
697            .into_any()
698    }
699}
700
// Presumably required for use as a `ToolbarItemView`; this file never emits
// `ToolbarItemEvent`s from the view.
impl EventEmitter<ToolbarItemEvent> for AcpToolsToolbarItemView {}
702
703impl ToolbarItemView for AcpToolsToolbarItemView {
704    fn set_active_pane_item(
705        &mut self,
706        active_pane_item: Option<&dyn ItemHandle>,
707        _window: &mut Window,
708        cx: &mut Context<Self>,
709    ) -> ToolbarItemLocation {
710        if let Some(item) = active_pane_item
711            && let Some(acp_tools) = item.downcast::<AcpTools>()
712        {
713            self.acp_tools = Some(acp_tools);
714            cx.notify();
715            return ToolbarItemLocation::PrimaryRight;
716        }
717        if self.acp_tools.take().is_some() {
718            cx.notify();
719        }
720        ToolbarItemLocation::Hidden
721    }
722}