client.rs

  1use crate::{
  2    adapters::DebugAdapterBinary,
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{Result, anyhow};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::{FutureExt as _, channel::oneshot, select};
 11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16    time::Duration,
 17};
 18
 19#[cfg(any(test, feature = "test-support"))]
 20const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
 21
 22#[cfg(not(any(test, feature = "test-support")))]
 23const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
 24
 25#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 26#[repr(transparent)]
 27pub struct SessionId(pub u32);
 28
 29impl SessionId {
 30    pub fn from_proto(client_id: u64) -> Self {
 31        Self(client_id as u32)
 32    }
 33
 34    pub fn to_proto(&self) -> u64 {
 35        self.0 as u64
 36    }
 37}
 38
/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
pub struct DebugAdapterClient {
    /// Session this client was started for; also used to tag the client's log lines.
    id: SessionId,
    /// Source of request sequence ids: starts at 1 and is advanced by `next_sequence_id`.
    sequence_count: AtomicU64,
    /// Launch/connection configuration the client was started with.
    binary: DebugAdapterBinary,
    /// Background executor; used to create the per-request timeout timer.
    executor: BackgroundExecutor,
    /// Owns the transport: sends messages, tracks/cancels pending requests, and shuts down.
    transport_delegate: TransportDelegate,
}

/// Callback that receives every event and reverse request coming from the adapter.
pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 49
 50impl DebugAdapterClient {
 51    pub async fn start(
 52        id: SessionId,
 53        binary: DebugAdapterBinary,
 54        message_handler: DapMessageHandler,
 55        cx: AsyncApp,
 56    ) -> Result<Self> {
 57        let ((server_rx, server_tx), transport_delegate) =
 58            TransportDelegate::start(&binary, cx.clone()).await?;
 59        let this = Self {
 60            id,
 61            binary,
 62            transport_delegate,
 63            sequence_count: AtomicU64::new(1),
 64            executor: cx.background_executor().clone(),
 65        };
 66        log::info!("Successfully connected to debug adapter");
 67
 68        let client_id = this.id;
 69
 70        // start handling events/reverse requests
 71        cx.background_spawn(Self::handle_receive_messages(
 72            client_id,
 73            server_rx,
 74            server_tx.clone(),
 75            message_handler,
 76        ))
 77        .detach();
 78
 79        Ok(this)
 80    }
 81
 82    pub async fn reconnect(
 83        &self,
 84        session_id: SessionId,
 85        binary: DebugAdapterBinary,
 86        message_handler: DapMessageHandler,
 87        cx: AsyncApp,
 88    ) -> Result<Self> {
 89        let binary = match self.transport_delegate.transport() {
 90            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 91                command: binary.command,
 92                arguments: binary.arguments,
 93                envs: binary.envs,
 94                cwd: binary.cwd,
 95                connection: Some(crate::adapters::TcpArguments {
 96                    host: tcp_transport.host,
 97                    port: tcp_transport.port,
 98                    timeout: Some(tcp_transport.timeout),
 99                }),
100                request_args: binary.request_args,
101            },
102            _ => self.binary.clone(),
103        };
104
105        Self::start(session_id, binary, message_handler, cx).await
106    }
107
108    async fn handle_receive_messages(
109        client_id: SessionId,
110        server_rx: Receiver<Message>,
111        client_tx: Sender<Message>,
112        mut message_handler: DapMessageHandler,
113    ) -> Result<()> {
114        let result = loop {
115            let message = match server_rx.recv().await {
116                Ok(message) => message,
117                Err(e) => break Err(e.into()),
118            };
119            match message {
120                Message::Event(ev) => {
121                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
122
123                    message_handler(Message::Event(ev))
124                }
125                Message::Request(req) => {
126                    log::debug!(
127                        "Client {} received reverse request `{}`",
128                        client_id.0,
129                        &req.command
130                    );
131
132                    message_handler(Message::Request(req))
133                }
134                Message::Response(response) => {
135                    log::debug!("Received response after request timeout: {:#?}", response);
136                }
137            }
138
139            smol::future::yield_now().await;
140        };
141
142        drop(client_tx);
143
144        log::debug!("Handle receive messages dropped");
145
146        result
147    }
148
149    /// Send a request to an adapter and get a response back
150    /// Note: This function will block until a response is sent back from the adapter
151    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
152        let serialized_arguments = serde_json::to_value(arguments)?;
153
154        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
155
156        let sequence_id = self.next_sequence_id();
157
158        let request = crate::messages::Request {
159            seq: sequence_id,
160            command: R::COMMAND.to_string(),
161            arguments: Some(serialized_arguments),
162        };
163        self.transport_delegate
164            .add_pending_request(sequence_id, callback_tx)
165            .await;
166
167        log::debug!(
168            "Client {} send `{}` request with sequence_id: {}",
169            self.id.0,
170            R::COMMAND,
171            sequence_id
172        );
173
174        self.send_message(Message::Request(request)).await?;
175
176        let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
177        let command = R::COMMAND.to_string();
178
179        select! {
180            response = callback_rx.fuse() => {
181                log::debug!(
182                    "Client {} received response for: `{}` sequence_id: {}",
183                    self.id.0,
184                    command,
185                    sequence_id
186                );
187
188                let response = response??;
189                match response.success {
190                    true => {
191                        if let Some(json) = response.body {
192                            Ok(serde_json::from_value(json)?)
193                        // Note: dap types configure themselves to return `None` when an empty object is received,
194                        // which then fails here...
195                        } else if let Ok(result) = serde_json::from_value(serde_json::Value::Object(Default::default())) {
196                            Ok(result)
197                        } else {
198                            Ok(serde_json::from_value(Default::default())?)
199                        }
200                    }
201                    false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
202                }
203            }
204
205            _ = timeout => {
206                self.transport_delegate.cancel_pending_request(&sequence_id).await;
207                log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
208                anyhow::bail!("DAP request timeout");
209            }
210        }
211    }
212
213    pub async fn send_message(&self, message: Message) -> Result<()> {
214        self.transport_delegate.send_message(message).await
215    }
216
217    pub fn id(&self) -> SessionId {
218        self.id
219    }
220
221    pub fn binary(&self) -> &DebugAdapterBinary {
222        &self.binary
223    }
224
225    /// Get the next sequence id to be used in a request
226    pub fn next_sequence_id(&self) -> u64 {
227        self.sequence_count.fetch_add(1, Ordering::Relaxed)
228    }
229
230    pub async fn shutdown(&self) -> Result<()> {
231        self.transport_delegate.shutdown().await
232    }
233
234    pub fn has_adapter_logs(&self) -> bool {
235        self.transport_delegate.has_adapter_logs()
236    }
237
238    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
239    where
240        F: 'static + Send + FnMut(IoKind, &str),
241    {
242        self.transport_delegate.add_log_handler(f, kind);
243    }
244
245    #[cfg(any(test, feature = "test-support"))]
246    pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
247    where
248        F: 'static
249            + Send
250            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
251    {
252        let transport = self.transport_delegate.transport().as_fake();
253        transport.on_request::<R, F>(handler);
254    }
255
256    #[cfg(any(test, feature = "test-support"))]
257    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
258        self.send_message(Message::Request(dap_types::messages::Request {
259            seq: self.sequence_count.load(Ordering::Relaxed),
260            command: R::COMMAND.into(),
261            arguments: serde_json::to_value(args).ok(),
262        }))
263        .await
264        .unwrap();
265    }
266
267    #[cfg(any(test, feature = "test-support"))]
268    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
269    where
270        F: 'static + Send + Fn(Response),
271    {
272        let transport = self.transport_delegate.transport().as_fake();
273        transport.on_response::<R, F>(handler).await;
274    }
275
276    #[cfg(any(test, feature = "test-support"))]
277    pub async fn fake_event(&self, event: dap_types::messages::Events) {
278        self.send_message(Message::Event(Box::new(event)))
279            .await
280            .unwrap();
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments, StartDebuggingRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Installs a test settings store and registers the debugger settings.
    /// When `RUST_LOG` is set, it also enables `env_logger` output.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        let logging_requested = std::env::var("RUST_LOG").is_ok();
        if logging_requested {
            let _ = env_logger::try_init();
        }

        cx.update(|app| {
            let store = SettingsStore::test(app);
            app.set_global(store);
            DebuggerSettings::register(app);
        });
    }

    /// Launch configuration shared by every test: a dummy `command` with no
    /// arguments, environment, working directory, or `connection`.
    fn test_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
            request_args: StartDebuggingRequestArguments {
                configuration: serde_json::Value::Null,
                request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
            },
        }
    }

    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client.on_request::<Initialize, _>(move |_, _| {
            Ok(dap_types::Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            })
        });

        cx.run_until_parked();

        let initialize_args = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let response = client
            .request::<Initialize>(initialize_args)
            .await
            .unwrap();

        cx.run_until_parked();

        let expected = dap_types::Capabilities {
            supports_configuration_done_request: Some(true),
            ..Default::default()
        };
        assert_eq!(expected, response);

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler_called_in_callback = Arc::clone(&handler_called);

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new(move |event| {
                handler_called_in_callback.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    event
                );
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler_called_in_callback = Arc::clone(&handler_called);

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new(move |event| {
                handler_called_in_callback.store(true, Ordering::SeqCst);

                let expected = Message::Request(dap_types::messages::Request {
                    seq: 1,
                    command: RunInTerminal::COMMAND.into(),
                    arguments: Some(json!({
                        "cwd": "/project/path/src",
                        "args": ["node", "test.js"],
                    })),
                });
                assert_eq!(expected, event);
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let run_in_terminal_args = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(run_in_terminal_args)
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}