acp_tools.rs

  1use std::{
  2    cell::RefCell,
  3    collections::HashSet,
  4    fmt::Display,
  5    rc::{Rc, Weak},
  6    sync::Arc,
  7};
  8
  9use agent_client_protocol as acp;
 10use collections::HashMap;
 11use gpui::{
 12    App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
 13    StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
 14};
 15use language::LanguageRegistry;
 16use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
 17use project::Project;
 18use settings::Settings;
 19use theme::ThemeSettings;
 20use ui::prelude::*;
 21use util::ResultExt as _;
 22use workspace::{Item, Workspace};
 23
// Declares the `OpenAcpLogs` action in the `dev` namespace; its handler is
// registered on every new workspace in `init`.
actions!(dev, [OpenAcpLogs]);
 25
 26pub fn init(cx: &mut App) {
 27    cx.observe_new(
 28        |workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
 29            workspace.register_action(|workspace, _: &OpenAcpLogs, window, cx| {
 30                let acp_tools =
 31                    Box::new(cx.new(|cx| AcpTools::new(workspace.project().clone(), cx)));
 32                workspace.add_item_to_active_pane(acp_tools, None, true, window, cx);
 33            });
 34        },
 35    )
 36    .detach();
 37}
 38
/// App-wide global that holds the single shared [`AcpConnectionRegistry`] entity.
///
/// It is installed lazily by [`AcpConnectionRegistry::default_global`].
struct GlobalAcpConnectionRegistry(Entity<AcpConnectionRegistry>);

impl Global for GlobalAcpConnectionRegistry {}
 42
/// Records which ACP connection the ACP log view should display.
///
/// The field sits behind a `RefCell`, so `set_active_connection` can update it
/// through `&self`.
#[derive(Default)]
pub struct AcpConnectionRegistry {
    /// The most recently registered connection, or `None` if none has been
    /// registered yet. The connection is held by `Weak`, so the registry does
    /// not keep it alive.
    active_connection: RefCell<Option<ActiveConnection>>,
}
 47
/// A connection registered with [`AcpConnectionRegistry`], together with the
/// display name of its server.
struct ActiveConnection {
    /// Shown in the log view's tab title ("ACP: <server_name>").
    server_name: SharedString,
    /// Weak handle; `upgrade()` fails once the connection has been dropped.
    connection: Weak<acp::ClientSideConnection>,
}
 52
 53impl AcpConnectionRegistry {
 54    pub fn default_global(cx: &mut App) -> Entity<Self> {
 55        if cx.has_global::<GlobalAcpConnectionRegistry>() {
 56            cx.global::<GlobalAcpConnectionRegistry>().0.clone()
 57        } else {
 58            let registry = cx.new(|_cx| AcpConnectionRegistry::default());
 59            cx.set_global(GlobalAcpConnectionRegistry(registry.clone()));
 60            registry
 61        }
 62    }
 63
 64    pub fn set_active_connection(
 65        &self,
 66        server_name: impl Into<SharedString>,
 67        connection: &Rc<acp::ClientSideConnection>,
 68        cx: &mut Context<Self>,
 69    ) {
 70        self.active_connection.replace(Some(ActiveConnection {
 71            server_name: server_name.into(),
 72            connection: Rc::downgrade(connection),
 73        }));
 74        cx.notify();
 75    }
 76}
 77
/// Workspace item that lists the messages exchanged on the active ACP
/// connection. Its tab is titled "ACP: <server>".
struct AcpTools {
    /// Supplies the language registry used to highlight JSON params.
    project: Entity<Project>,
    focus_handle: FocusHandle,
    /// Indices into the watched connection's `messages` for the rows that are
    /// currently shown expanded.
    expanded: HashSet<usize>,
    /// The connection whose messages are displayed, once one has been attached.
    watched_connection: Option<WatchedConnection>,
    /// The global registry observed for changes of the active connection.
    connection_registry: Entity<AcpConnectionRegistry>,
    /// Keeps the observation of `connection_registry` alive.
    _subscription: Subscription,
}
 86
/// The state of one observed connection: its recorded messages, plus the
/// bookkeeping needed to label responses.
struct WatchedConnection {
    server_name: SharedString,
    /// Messages in arrival order; each one is a row of the list.
    messages: Vec<WatchedConnectionMessage>,
    /// List state kept in sync with `messages` (one item per message).
    list_state: ListState,
    /// Compared with the registry's active connection to detect a switch.
    connection: Weak<acp::ClientSideConnection>,
    /// Method names of incoming requests, keyed by request id. Consumed when
    /// the matching outgoing response is seen.
    incoming_request_methods: HashMap<i32, Arc<str>>,
    /// Method names of outgoing requests, keyed by request id. Consumed when
    /// the matching incoming response is seen.
    outgoing_request_methods: HashMap<i32, Arc<str>>,
    /// Task that forwards messages from the connection's subscription into the
    /// view. It is owned here rather than detached, so it is dropped together
    /// with this struct.
    _task: Task<()>,
}
 96
 97impl AcpTools {
 98    fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 99        let connection_registry = AcpConnectionRegistry::default_global(cx);
100
101        let subscription = cx.observe(&connection_registry, |this, _, cx| {
102            this.update_connection(cx);
103            cx.notify();
104        });
105
106        let mut this = Self {
107            project,
108            focus_handle: cx.focus_handle(),
109            expanded: HashSet::default(),
110            watched_connection: None,
111            connection_registry,
112            _subscription: subscription,
113        };
114        this.update_connection(cx);
115        this
116    }
117
118    fn update_connection(&mut self, cx: &mut Context<Self>) {
119        let active_connection = self.connection_registry.read(cx).active_connection.borrow();
120        let Some(active_connection) = active_connection.as_ref() else {
121            return;
122        };
123
124        if let Some(watched_connection) = self.watched_connection.as_ref() {
125            if Weak::ptr_eq(
126                &watched_connection.connection,
127                &active_connection.connection,
128            ) {
129                return;
130            }
131        }
132
133        if let Some(connection) = active_connection.connection.upgrade() {
134            let mut receiver = connection.subscribe();
135            let task = cx.spawn(async move |this, cx| {
136                while let Ok(message) = receiver.recv().await {
137                    this.update(cx, |this, cx| {
138                        this.push_stream_message(message, cx);
139                    })
140                    .ok();
141                }
142            });
143
144            self.watched_connection = Some(WatchedConnection {
145                server_name: active_connection.server_name.clone(),
146                messages: vec![],
147                list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
148                connection: active_connection.connection.clone(),
149                incoming_request_methods: HashMap::default(),
150                outgoing_request_methods: HashMap::default(),
151                _task: task,
152            });
153        }
154    }
155
156    fn push_stream_message(&mut self, stream_message: acp::StreamMessage, cx: &mut Context<Self>) {
157        let Some(connection) = self.watched_connection.as_mut() else {
158            return;
159        };
160        let language_registry = self.project.read(cx).languages().clone();
161        let index = connection.messages.len();
162
163        let (request_id, method, message_type, params) = match stream_message.message {
164            acp::StreamMessageContent::Request { id, method, params } => {
165                let method_map = match stream_message.direction {
166                    acp::StreamMessageDirection::Incoming => {
167                        &mut connection.incoming_request_methods
168                    }
169                    acp::StreamMessageDirection::Outgoing => {
170                        &mut connection.outgoing_request_methods
171                    }
172                };
173
174                method_map.insert(id, method.clone());
175                (Some(id), method.into(), MessageType::Request, Ok(params))
176            }
177            acp::StreamMessageContent::Response { id, result } => {
178                let method_map = match stream_message.direction {
179                    acp::StreamMessageDirection::Incoming => {
180                        &mut connection.outgoing_request_methods
181                    }
182                    acp::StreamMessageDirection::Outgoing => {
183                        &mut connection.incoming_request_methods
184                    }
185                };
186
187                if let Some(method) = method_map.remove(&id) {
188                    (Some(id), method.into(), MessageType::Response, result)
189                } else {
190                    (
191                        Some(id),
192                        "[unrecognized response]".into(),
193                        MessageType::Response,
194                        result,
195                    )
196                }
197            }
198            acp::StreamMessageContent::Notification { method, params } => {
199                (None, method.into(), MessageType::Notification, Ok(params))
200            }
201        };
202
203        let message = WatchedConnectionMessage {
204            name: method,
205            message_type,
206            request_id,
207            direction: stream_message.direction,
208            collapsed_params_md: match params.as_ref() {
209                Ok(params) => params
210                    .as_ref()
211                    .map(|params| collapsed_params_md(params, &language_registry, cx)),
212                Err(err) => {
213                    if let Ok(err) = &serde_json::to_value(err) {
214                        Some(collapsed_params_md(&err, &language_registry, cx))
215                    } else {
216                        None
217                    }
218                }
219            },
220
221            expanded_params_md: None,
222            params,
223        };
224
225        connection.messages.push(message);
226        connection.list_state.splice(index..index, 1);
227        cx.notify();
228    }
229
230    fn render_message(
231        &mut self,
232        index: usize,
233        window: &mut Window,
234        cx: &mut Context<Self>,
235    ) -> AnyElement {
236        let Some(connection) = self.watched_connection.as_ref() else {
237            return Empty.into_any();
238        };
239
240        let Some(message) = connection.messages.get(index) else {
241            return Empty.into_any();
242        };
243
244        let base_size = TextSize::Editor.rems(cx);
245
246        let theme_settings = ThemeSettings::get_global(cx);
247        let text_style = window.text_style();
248
249        let colors = cx.theme().colors();
250        let expanded = self.expanded.contains(&index);
251
252        v_flex()
253            .w_full()
254            .px_4()
255            .py_3()
256            .border_color(colors.border)
257            .border_b_1()
258            .gap_2()
259            .items_start()
260            .font_buffer(cx)
261            .text_size(base_size)
262            .id(index)
263            .group("message")
264            .hover(|this| this.bg(colors.element_background.opacity(0.5)))
265            .on_click(cx.listener(move |this, _, _, cx| {
266                if this.expanded.contains(&index) {
267                    this.expanded.remove(&index);
268                } else {
269                    this.expanded.insert(index);
270                    let Some(connection) = &mut this.watched_connection else {
271                        return;
272                    };
273                    let Some(message) = connection.messages.get_mut(index) else {
274                        return;
275                    };
276                    message.expanded(this.project.read(cx).languages().clone(), cx);
277                    connection.list_state.scroll_to_reveal_item(index);
278                }
279                cx.notify()
280            }))
281            .child(
282                h_flex()
283                    .w_full()
284                    .gap_2()
285                    .items_center()
286                    .flex_shrink_0()
287                    .child(match message.direction {
288                        acp::StreamMessageDirection::Incoming => {
289                            ui::Icon::new(ui::IconName::ArrowDown).color(Color::Error)
290                        }
291                        acp::StreamMessageDirection::Outgoing => {
292                            ui::Icon::new(ui::IconName::ArrowUp).color(Color::Success)
293                        }
294                    })
295                    .child(
296                        Label::new(message.name.clone())
297                            .buffer_font(cx)
298                            .color(Color::Muted),
299                    )
300                    .child(div().flex_1())
301                    .child(
302                        div()
303                            .child(ui::Chip::new(message.message_type.to_string()))
304                            .visible_on_hover("message"),
305                    )
306                    .children(
307                        message
308                            .request_id
309                            .map(|req_id| div().child(ui::Chip::new(req_id.to_string()))),
310                    ),
311            )
312            // I'm aware using markdown is a hack. Trying to get something working for the demo.
313            // Will clean up soon!
314            .when_some(
315                if expanded {
316                    message.expanded_params_md.clone()
317                } else {
318                    message.collapsed_params_md.clone()
319                },
320                |this, params| {
321                    this.child(
322                        div().pl_6().w_full().child(
323                            MarkdownElement::new(
324                                params,
325                                MarkdownStyle {
326                                    base_text_style: text_style,
327                                    selection_background_color: colors.element_selection_background,
328                                    syntax: cx.theme().syntax().clone(),
329                                    code_block_overflow_x_scroll: true,
330                                    code_block: StyleRefinement {
331                                        text: Some(TextStyleRefinement {
332                                            font_family: Some(
333                                                theme_settings.buffer_font.family.clone(),
334                                            ),
335                                            font_size: Some((base_size * 0.8).into()),
336                                            ..Default::default()
337                                        }),
338                                        ..Default::default()
339                                    },
340                                    ..Default::default()
341                                },
342                            )
343                            .code_block_renderer(
344                                CodeBlockRenderer::Default {
345                                    copy_button: false,
346                                    copy_button_on_hover: expanded,
347                                    border: false,
348                                },
349                            ),
350                        ),
351                    )
352                },
353            )
354            .into_any()
355    }
356}
357
/// One recorded message (request, response or notification) with its
/// pre-rendered params.
struct WatchedConnectionMessage {
    /// Method name, or `[unrecognized response]` for a response whose request
    /// was never seen.
    name: SharedString,
    /// `None` for notifications.
    request_id: Option<i32>,
    direction: acp::StreamMessageDirection,
    message_type: MessageType,
    /// Raw params or result. `Err` holds an error response.
    params: Result<Option<serde_json::Value>, acp::Error>,
    /// Compact single-line rendering, built when the message is recorded.
    collapsed_params_md: Option<Entity<Markdown>>,
    /// Pretty-printed rendering, filled in by `expanded` when the row is
    /// expanded.
    expanded_params_md: Option<Entity<Markdown>>,
}
367
368impl WatchedConnectionMessage {
369    fn expanded(&mut self, language_registry: Arc<LanguageRegistry>, cx: &mut App) {
370        let params_md = match &self.params {
371            Ok(Some(params)) => Some(expanded_params_md(params, &language_registry, cx)),
372            Err(err) => {
373                if let Some(err) = &serde_json::to_value(err).log_err() {
374                    Some(expanded_params_md(&err, &language_registry, cx))
375                } else {
376                    None
377                }
378            }
379            _ => None,
380        };
381        self.expanded_params_md = params_md;
382    }
383}
384
385fn collapsed_params_md(
386    params: &serde_json::Value,
387    language_registry: &Arc<LanguageRegistry>,
388    cx: &mut App,
389) -> Entity<Markdown> {
390    let params_json = serde_json::to_string(params).unwrap_or_default();
391    let mut spaced_out_json = String::with_capacity(params_json.len() + params_json.len() / 4);
392
393    for ch in params_json.chars() {
394        match ch {
395            '{' => spaced_out_json.push_str("{ "),
396            '}' => spaced_out_json.push_str(" }"),
397            ':' => spaced_out_json.push_str(": "),
398            ',' => spaced_out_json.push_str(", "),
399            c => spaced_out_json.push(c),
400        }
401    }
402
403    let params_md = format!("```json\n{}\n```", spaced_out_json);
404    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
405}
406
407fn expanded_params_md(
408    params: &serde_json::Value,
409    language_registry: &Arc<LanguageRegistry>,
410    cx: &mut App,
411) -> Entity<Markdown> {
412    let params_json = serde_json::to_string_pretty(params).unwrap_or_default();
413    let params_md = format!("```json\n{}\n```", params_json);
414    cx.new(|cx| Markdown::new(params_md.into(), Some(language_registry.clone()), None, cx))
415}
416
/// Kind of a recorded ACP message. Its `Display` text is shown as a chip on
/// each log row.
enum MessageType {
    Request,
    Response,
    Notification,
}

impl Display for MessageType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Request => "Request",
            Self::Response => "Response",
            Self::Notification => "Notification",
        };
        f.write_str(label)
    }
}
432
/// Event type for [`AcpTools`]. The enum has no variants, so no events are
/// ever emitted; it exists only to satisfy `Item::Event`.
enum AcpToolsEvent {}

impl EventEmitter<AcpToolsEvent> for AcpTools {}
436
437impl Item for AcpTools {
438    type Event = AcpToolsEvent;
439
440    fn tab_content_text(&self, _detail: usize, _cx: &App) -> ui::SharedString {
441        format!(
442            "ACP: {}",
443            self.watched_connection
444                .as_ref()
445                .map_or("Disconnected", |connection| &connection.server_name)
446        )
447        .into()
448    }
449
450    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
451        Some(ui::Icon::new(IconName::Thread))
452    }
453}
454
455impl Focusable for AcpTools {
456    fn focus_handle(&self, _cx: &App) -> FocusHandle {
457        self.focus_handle.clone()
458    }
459}
460
461impl Render for AcpTools {
462    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
463        v_flex()
464            .track_focus(&self.focus_handle)
465            .size_full()
466            .bg(cx.theme().colors().editor_background)
467            .child(match self.watched_connection.as_ref() {
468                Some(connection) => {
469                    if connection.messages.is_empty() {
470                        h_flex()
471                            .size_full()
472                            .justify_center()
473                            .items_center()
474                            .child("No messages recorded yet")
475                            .into_any()
476                    } else {
477                        list(
478                            connection.list_state.clone(),
479                            cx.processor(Self::render_message),
480                        )
481                        .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
482                        .flex_grow()
483                        .into_any()
484                    }
485                }
486                None => h_flex()
487                    .size_full()
488                    .justify_center()
489                    .items_center()
490                    .child("No active connection")
491                    .into_any(),
492            })
493    }
494}