client.rs

  1use crate::{
  2    adapters::{DebugAdapterBinary, DebugAdapterName},
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{anyhow, Result};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::{channel::oneshot, select, FutureExt as _};
 11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16    time::Duration,
 17};
 18
 19#[cfg(any(test, feature = "test-support"))]
 20const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
 21
 22#[cfg(not(any(test, feature = "test-support")))]
 23const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
 24
 25#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 26#[repr(transparent)]
 27pub struct SessionId(pub u32);
 28
 29impl SessionId {
 30    pub fn from_proto(client_id: u64) -> Self {
 31        Self(client_id as u32)
 32    }
 33
 34    pub fn to_proto(&self) -> u64 {
 35        self.0 as u64
 36    }
 37}
 38
 39/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
 40pub struct DebugAdapterClient {
 41    id: SessionId,
 42    name: DebugAdapterName,
 43    sequence_count: AtomicU64,
 44    binary: DebugAdapterBinary,
 45    executor: BackgroundExecutor,
 46    transport_delegate: TransportDelegate,
 47}
 48
 49pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 50
 51impl DebugAdapterClient {
 52    pub async fn start(
 53        id: SessionId,
 54        name: DebugAdapterName,
 55        binary: DebugAdapterBinary,
 56        message_handler: DapMessageHandler,
 57        cx: AsyncApp,
 58    ) -> Result<Self> {
 59        let ((server_rx, server_tx), transport_delegate) =
 60            TransportDelegate::start(&binary, cx.clone()).await?;
 61        let this = Self {
 62            id,
 63            name,
 64            binary,
 65            transport_delegate,
 66            sequence_count: AtomicU64::new(1),
 67            executor: cx.background_executor().clone(),
 68        };
 69        log::info!("Successfully connected to debug adapter");
 70
 71        let client_id = this.id;
 72
 73        // start handling events/reverse requests
 74
 75        cx.background_spawn(Self::handle_receive_messages(
 76            client_id,
 77            server_rx,
 78            server_tx.clone(),
 79            message_handler,
 80        ))
 81        .detach();
 82
 83        Ok(this)
 84    }
 85
 86    pub async fn reconnect(
 87        &self,
 88        session_id: SessionId,
 89        binary: DebugAdapterBinary,
 90        message_handler: DapMessageHandler,
 91        cx: AsyncApp,
 92    ) -> Result<Self> {
 93        let binary = match self.transport_delegate.transport() {
 94            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 95                command: binary.command,
 96                arguments: binary.arguments,
 97                envs: binary.envs,
 98                cwd: binary.cwd,
 99                connection: Some(crate::adapters::TcpArguments {
100                    host: tcp_transport.host,
101                    port: tcp_transport.port,
102                    timeout: Some(tcp_transport.timeout),
103                }),
104            },
105            _ => self.binary.clone(),
106        };
107
108        Self::start(session_id, self.name(), binary, message_handler, cx).await
109    }
110
111    async fn handle_receive_messages(
112        client_id: SessionId,
113        server_rx: Receiver<Message>,
114        client_tx: Sender<Message>,
115        mut message_handler: DapMessageHandler,
116    ) -> Result<()> {
117        let result = loop {
118            let message = match server_rx.recv().await {
119                Ok(message) => message,
120                Err(e) => break Err(e.into()),
121            };
122
123            match message {
124                Message::Event(ev) => {
125                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
126
127                    message_handler(Message::Event(ev))
128                }
129                Message::Request(req) => {
130                    log::debug!(
131                        "Client {} received reverse request `{}`",
132                        client_id.0,
133                        &req.command
134                    );
135
136                    message_handler(Message::Request(req))
137                }
138                Message::Response(response) => {
139                    log::debug!("Received response after request timeout: {:#?}", response);
140                }
141            }
142
143            smol::future::yield_now().await;
144        };
145
146        drop(client_tx);
147
148        log::debug!("Handle receive messages dropped");
149
150        result
151    }
152
153    /// Send a request to an adapter and get a response back
154    /// Note: This function will block until a response is sent back from the adapter
155    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
156        let serialized_arguments = serde_json::to_value(arguments)?;
157
158        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
159
160        let sequence_id = self.next_sequence_id();
161
162        let request = crate::messages::Request {
163            seq: sequence_id,
164            command: R::COMMAND.to_string(),
165            arguments: Some(serialized_arguments),
166        };
167
168        self.transport_delegate
169            .add_pending_request(sequence_id, callback_tx)
170            .await;
171
172        log::debug!(
173            "Client {} send `{}` request with sequence_id: {}",
174            self.id.0,
175            R::COMMAND.to_string(),
176            sequence_id
177        );
178
179        self.send_message(Message::Request(request)).await?;
180
181        let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
182        let command = R::COMMAND.to_string();
183
184        select! {
185            response = callback_rx.fuse() => {
186                log::debug!(
187                    "Client {} received response for: `{}` sequence_id: {}",
188                    self.id.0,
189                    command,
190                    sequence_id
191                );
192
193                let response = response??;
194                match response.success {
195                    true => Ok(serde_json::from_value(response.body.unwrap_or_default())?),
196                    false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
197                }
198            }
199
200            _ = timeout => {
201                self.transport_delegate.cancel_pending_request(&sequence_id).await;
202                log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
203                anyhow::bail!("DAP request timeout");
204            }
205        }
206    }
207
208    pub async fn send_message(&self, message: Message) -> Result<()> {
209        self.transport_delegate.send_message(message).await
210    }
211
212    pub fn id(&self) -> SessionId {
213        self.id
214    }
215
216    pub fn name(&self) -> DebugAdapterName {
217        self.name.clone()
218    }
219    pub fn binary(&self) -> &DebugAdapterBinary {
220        &self.binary
221    }
222
223    /// Get the next sequence id to be used in a request
224    pub fn next_sequence_id(&self) -> u64 {
225        self.sequence_count.fetch_add(1, Ordering::Relaxed)
226    }
227
228    pub async fn shutdown(&self) -> Result<()> {
229        self.transport_delegate.shutdown().await
230    }
231
232    pub fn has_adapter_logs(&self) -> bool {
233        self.transport_delegate.has_adapter_logs()
234    }
235
236    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
237    where
238        F: 'static + Send + FnMut(IoKind, &str),
239    {
240        self.transport_delegate.add_log_handler(f, kind);
241    }
242
243    #[cfg(any(test, feature = "test-support"))]
244    pub async fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
245    where
246        F: 'static
247            + Send
248            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
249    {
250        let transport = self.transport_delegate.transport().as_fake();
251        transport.on_request::<R, F>(handler).await;
252    }
253
254    #[cfg(any(test, feature = "test-support"))]
255    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
256        self.send_message(Message::Request(dap_types::messages::Request {
257            seq: self.sequence_count.load(Ordering::Relaxed),
258            command: R::COMMAND.into(),
259            arguments: serde_json::to_value(args).ok(),
260        }))
261        .await
262        .unwrap();
263    }
264
265    #[cfg(any(test, feature = "test-support"))]
266    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
267    where
268        F: 'static + Send + Fn(Response),
269    {
270        let transport = self.transport_delegate.transport().as_fake();
271        transport.on_response::<R, F>(handler).await;
272    }
273
274    #[cfg(any(test, feature = "test-support"))]
275    pub async fn fake_event(&self, event: dap_types::messages::Events) {
276        self.send_message(Message::Event(Box::new(event)))
277            .await
278            .unwrap();
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments,
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    /// Enables logging when `RUST_LOG` is set and registers the settings the
    /// client tests rely on.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        let logging_requested = std::env::var("RUST_LOG").is_ok();
        if logging_requested {
            env_logger::try_init().ok();
        }

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            DebuggerSettings::register(cx);
        });
    }

    /// Binary description shared by every test in this module.
    fn test_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
        }
    }

    /// A request registered through `on_request` is answered with the handler's value.
    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            test_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client
            .on_request::<Initialize, _>(move |_, _| {
                Ok(dap_types::Capabilities {
                    supports_configuration_done_request: Some(true),
                    ..Default::default()
                })
            })
            .await;

        cx.run_until_parked();

        let initialize_arguments = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };

        let response = client
            .request::<Initialize>(initialize_arguments)
            .await
            .unwrap();

        cx.run_until_parked();

        assert_eq!(
            dap_types::Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            },
            response
        );

        client.shutdown().await.unwrap();
    }

    /// An event sent by the adapter reaches the message handler unchanged.
    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));

        let message_handler: DapMessageHandler = {
            let handler_called = handler_called.clone();
            Box::new(move |message: Message| {
                handler_called.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    message
                );
            })
        };

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            test_binary(),
            message_handler,
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    /// A reverse request sent by the adapter reaches the message handler unchanged.
    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));

        let message_handler: DapMessageHandler = {
            let handler_called = handler_called.clone();
            Box::new(move |message: Message| {
                handler_called.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Request(dap_types::messages::Request {
                        seq: 1,
                        command: RunInTerminal::COMMAND.into(),
                        arguments: Some(json!({
                            "cwd": "/project/path/src",
                            "args": ["node", "test.js"],
                        }))
                    }),
                    message
                );
            })
        };

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName(Arc::from("test-adapter")),
            test_binary(),
            message_handler,
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let reverse_request_arguments = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };

        client
            .fake_reverse_request::<RunInTerminal>(reverse_request_arguments)
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}