client.rs

  1use crate::{
  2    adapters::{DebugAdapterBinary, DebugAdapterName},
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{Result, anyhow};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::{FutureExt as _, channel::oneshot, select};
 11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16    time::Duration,
 17};
 18
 19#[cfg(any(test, feature = "test-support"))]
 20const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
 21
 22#[cfg(not(any(test, feature = "test-support")))]
 23const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
 24
 25#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 26#[repr(transparent)]
 27pub struct SessionId(pub u32);
 28
 29impl SessionId {
 30    pub fn from_proto(client_id: u64) -> Self {
 31        Self(client_id as u32)
 32    }
 33
 34    pub fn to_proto(&self) -> u64 {
 35        self.0 as u64
 36    }
 37}
 38
/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
pub struct DebugAdapterClient {
    /// Session this client belongs to.
    id: SessionId,
    /// Source of request sequence ids (starts at 1); see `next_sequence_id`.
    sequence_count: AtomicU64,
    /// Binary/connection configuration this client was started with.
    binary: DebugAdapterBinary,
    /// Executor used to create the per-request timeout timers.
    executor: BackgroundExecutor,
    /// Transport used to send messages and track pending requests.
    transport_delegate: TransportDelegate,
}
 47
 48pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 49
 50impl DebugAdapterClient {
 51    pub async fn start(
 52        id: SessionId,
 53        binary: DebugAdapterBinary,
 54        message_handler: DapMessageHandler,
 55        cx: AsyncApp,
 56    ) -> Result<Self> {
 57        let ((server_rx, server_tx), transport_delegate) =
 58            TransportDelegate::start(&binary, cx.clone()).await?;
 59        let this = Self {
 60            id,
 61            binary,
 62            transport_delegate,
 63            sequence_count: AtomicU64::new(1),
 64            executor: cx.background_executor().clone(),
 65        };
 66        log::info!("Successfully connected to debug adapter");
 67
 68        let client_id = this.id;
 69
 70        // start handling events/reverse requests
 71        cx.background_spawn(Self::handle_receive_messages(
 72            client_id,
 73            server_rx,
 74            server_tx.clone(),
 75            message_handler,
 76        ))
 77        .detach();
 78
 79        Ok(this)
 80    }
 81
 82    pub async fn reconnect(
 83        &self,
 84        session_id: SessionId,
 85        binary: DebugAdapterBinary,
 86        message_handler: DapMessageHandler,
 87        cx: AsyncApp,
 88    ) -> Result<Self> {
 89        let binary = match self.transport_delegate.transport() {
 90            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 91                adapter_name: binary.adapter_name,
 92                command: binary.command,
 93                arguments: binary.arguments,
 94                envs: binary.envs,
 95                cwd: binary.cwd,
 96                connection: Some(crate::adapters::TcpArguments {
 97                    host: tcp_transport.host,
 98                    port: tcp_transport.port,
 99                    timeout: Some(tcp_transport.timeout),
100                }),
101                request_args: binary.request_args,
102            },
103            _ => self.binary.clone(),
104        };
105
106        Self::start(session_id, binary, message_handler, cx).await
107    }
108
109    async fn handle_receive_messages(
110        client_id: SessionId,
111        server_rx: Receiver<Message>,
112        client_tx: Sender<Message>,
113        mut message_handler: DapMessageHandler,
114    ) -> Result<()> {
115        let result = loop {
116            let message = match server_rx.recv().await {
117                Ok(message) => message,
118                Err(e) => break Err(e.into()),
119            };
120            match message {
121                Message::Event(ev) => {
122                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
123
124                    message_handler(Message::Event(ev))
125                }
126                Message::Request(req) => {
127                    log::debug!(
128                        "Client {} received reverse request `{}`",
129                        client_id.0,
130                        &req.command
131                    );
132
133                    message_handler(Message::Request(req))
134                }
135                Message::Response(response) => {
136                    log::debug!("Received response after request timeout: {:#?}", response);
137                }
138            }
139
140            smol::future::yield_now().await;
141        };
142
143        drop(client_tx);
144
145        log::debug!("Handle receive messages dropped");
146
147        result
148    }
149
150    /// Send a request to an adapter and get a response back
151    /// Note: This function will block until a response is sent back from the adapter
152    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
153        let serialized_arguments = serde_json::to_value(arguments)?;
154
155        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
156
157        let sequence_id = self.next_sequence_id();
158
159        let request = crate::messages::Request {
160            seq: sequence_id,
161            command: R::COMMAND.to_string(),
162            arguments: Some(serialized_arguments),
163        };
164        self.transport_delegate
165            .add_pending_request(sequence_id, callback_tx)
166            .await;
167
168        log::debug!(
169            "Client {} send `{}` request with sequence_id: {}",
170            self.id.0,
171            R::COMMAND,
172            sequence_id
173        );
174
175        self.send_message(Message::Request(request)).await?;
176
177        let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
178        let command = R::COMMAND.to_string();
179
180        select! {
181            response = callback_rx.fuse() => {
182                log::debug!(
183                    "Client {} received response for: `{}` sequence_id: {}",
184                    self.id.0,
185                    command,
186                    sequence_id
187                );
188
189                let response = response??;
190                match response.success {
191                    true => {
192                        if let Some(json) = response.body {
193                            Ok(serde_json::from_value(json)?)
194                        // Note: dap types configure themselves to return `None` when an empty object is received,
195                        // which then fails here...
196                        } else if let Ok(result) = serde_json::from_value(serde_json::Value::Object(Default::default())) {
197                            Ok(result)
198                        } else {
199                            Ok(serde_json::from_value(Default::default())?)
200                        }
201                    }
202                    false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
203                }
204            }
205
206            _ = timeout => {
207                self.transport_delegate.cancel_pending_request(&sequence_id).await;
208                log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
209                anyhow::bail!("DAP request timeout");
210            }
211        }
212    }
213
214    pub async fn send_message(&self, message: Message) -> Result<()> {
215        self.transport_delegate.send_message(message).await
216    }
217
218    pub fn id(&self) -> SessionId {
219        self.id
220    }
221
222    pub fn name(&self) -> DebugAdapterName {
223        self.binary.adapter_name.clone()
224    }
225    pub fn binary(&self) -> &DebugAdapterBinary {
226        &self.binary
227    }
228
229    /// Get the next sequence id to be used in a request
230    pub fn next_sequence_id(&self) -> u64 {
231        self.sequence_count.fetch_add(1, Ordering::Relaxed)
232    }
233
234    pub async fn shutdown(&self) -> Result<()> {
235        self.transport_delegate.shutdown().await
236    }
237
238    pub fn has_adapter_logs(&self) -> bool {
239        self.transport_delegate.has_adapter_logs()
240    }
241
242    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
243    where
244        F: 'static + Send + FnMut(IoKind, &str),
245    {
246        self.transport_delegate.add_log_handler(f, kind);
247    }
248
249    #[cfg(any(test, feature = "test-support"))]
250    pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
251    where
252        F: 'static
253            + Send
254            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
255    {
256        let transport = self.transport_delegate.transport().as_fake();
257        transport.on_request::<R, F>(handler);
258    }
259
260    #[cfg(any(test, feature = "test-support"))]
261    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
262        self.send_message(Message::Request(dap_types::messages::Request {
263            seq: self.sequence_count.load(Ordering::Relaxed),
264            command: R::COMMAND.into(),
265            arguments: serde_json::to_value(args).ok(),
266        }))
267        .await
268        .unwrap();
269    }
270
271    #[cfg(any(test, feature = "test-support"))]
272    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
273    where
274        F: 'static + Send + Fn(Response),
275    {
276        let transport = self.transport_delegate.transport().as_fake();
277        transport.on_response::<R, F>(handler).await;
278    }
279
280    #[cfg(any(test, feature = "test-support"))]
281    pub async fn fake_event(&self, event: dap_types::messages::Events) {
282        self.send_message(Message::Event(Box::new(event)))
283            .await
284            .unwrap();
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments, StartDebuggingRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Installs the test settings store and debugger settings; enables
    /// `env_logger` when `RUST_LOG` is set.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        let wants_logging = std::env::var("RUST_LOG").is_ok();
        if wants_logging {
            env_logger::try_init().ok();
        }

        cx.update(|app| {
            let settings_store = SettingsStore::test(app);
            app.set_global(settings_store);
            DebuggerSettings::register(app);
        });
    }

    /// Launch-style binary for the fake adapter; only the adapter name varies.
    fn fake_binary(adapter_name: &'static str) -> DebugAdapterBinary {
        DebugAdapterBinary {
            adapter_name: adapter_name.into(),
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
            request_args: StartDebuggingRequestArguments {
                configuration: serde_json::Value::Null,
                request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
            },
        }
    }

    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            crate::client::SessionId(1),
            fake_binary("adapter"),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client.on_request::<Initialize, _>(move |_, _| {
            Ok(dap_types::Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            })
        });

        cx.run_until_parked();

        let initialize_args = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let response = client
            .request::<Initialize>(initialize_args)
            .await
            .unwrap();

        cx.run_until_parked();

        let expected = dap_types::Capabilities {
            supports_configuration_done_request: Some(true),
            ..Default::default()
        };
        assert_eq!(expected, response);

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));
        let handler_flag = Arc::clone(&called_event_handler);

        let client = DebugAdapterClient::start(
            crate::client::SessionId(1),
            fake_binary("adapter"),
            Box::new(move |event| {
                handler_flag.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    event
                );
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));
        let handler_flag = Arc::clone(&called_event_handler);

        let client = DebugAdapterClient::start(
            crate::client::SessionId(1),
            fake_binary("test-adapter"),
            Box::new(move |event| {
                handler_flag.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Request(dap_types::messages::Request {
                        seq: 1,
                        command: RunInTerminal::COMMAND.into(),
                        arguments: Some(json!({
                            "cwd": "/project/path/src",
                            "args": ["node", "test.js"],
                        }))
                    }),
                    event
                );
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let reverse_request_args = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(reverse_request_args)
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}