client.rs

  1use crate::{
  2    adapters::{DebugAdapterBinary, DebugAdapterName},
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{Result, anyhow};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::{FutureExt as _, channel::oneshot, select};
 11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16    time::Duration,
 17};
 18
 19#[cfg(any(test, feature = "test-support"))]
 20const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
 21
 22#[cfg(not(any(test, feature = "test-support")))]
 23const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
 24
 25#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 26#[repr(transparent)]
 27pub struct SessionId(pub u32);
 28
 29impl SessionId {
 30    pub fn from_proto(client_id: u64) -> Self {
 31        Self(client_id as u32)
 32    }
 33
 34    pub fn to_proto(&self) -> u64 {
 35        self.0 as u64
 36    }
 37}
 38
 39/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
 40pub struct DebugAdapterClient {
 41    id: SessionId,
 42    name: DebugAdapterName,
 43    sequence_count: AtomicU64,
 44    binary: DebugAdapterBinary,
 45    executor: BackgroundExecutor,
 46    transport_delegate: TransportDelegate,
 47}
 48
 49pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 50
 51impl DebugAdapterClient {
 52    pub async fn start(
 53        id: SessionId,
 54        name: DebugAdapterName,
 55        binary: DebugAdapterBinary,
 56        message_handler: DapMessageHandler,
 57        cx: AsyncApp,
 58    ) -> Result<Self> {
 59        let ((server_rx, server_tx), transport_delegate) =
 60            TransportDelegate::start(&binary, cx.clone()).await?;
 61        let this = Self {
 62            id,
 63            name,
 64            binary,
 65            transport_delegate,
 66            sequence_count: AtomicU64::new(1),
 67            executor: cx.background_executor().clone(),
 68        };
 69        log::info!("Successfully connected to debug adapter");
 70
 71        let client_id = this.id;
 72
 73        // start handling events/reverse requests
 74        cx.background_spawn(Self::handle_receive_messages(
 75            client_id,
 76            server_rx,
 77            server_tx.clone(),
 78            message_handler,
 79        ))
 80        .detach();
 81
 82        Ok(this)
 83    }
 84
 85    pub async fn reconnect(
 86        &self,
 87        session_id: SessionId,
 88        binary: DebugAdapterBinary,
 89        message_handler: DapMessageHandler,
 90        cx: AsyncApp,
 91    ) -> Result<Self> {
 92        let binary = match self.transport_delegate.transport() {
 93            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 94                command: binary.command,
 95                arguments: binary.arguments,
 96                envs: binary.envs,
 97                cwd: binary.cwd,
 98                connection: Some(crate::adapters::TcpArguments {
 99                    host: tcp_transport.host,
100                    port: tcp_transport.port,
101                    timeout: Some(tcp_transport.timeout),
102                }),
103            },
104            _ => self.binary.clone(),
105        };
106
107        Self::start(session_id, self.name(), binary, message_handler, cx).await
108    }
109
110    async fn handle_receive_messages(
111        client_id: SessionId,
112        server_rx: Receiver<Message>,
113        client_tx: Sender<Message>,
114        mut message_handler: DapMessageHandler,
115    ) -> Result<()> {
116        let result = loop {
117            let message = match server_rx.recv().await {
118                Ok(message) => message,
119                Err(e) => break Err(e.into()),
120            };
121            match message {
122                Message::Event(ev) => {
123                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
124
125                    message_handler(Message::Event(ev))
126                }
127                Message::Request(req) => {
128                    log::debug!(
129                        "Client {} received reverse request `{}`",
130                        client_id.0,
131                        &req.command
132                    );
133
134                    message_handler(Message::Request(req))
135                }
136                Message::Response(response) => {
137                    log::debug!("Received response after request timeout: {:#?}", response);
138                }
139            }
140
141            smol::future::yield_now().await;
142        };
143
144        drop(client_tx);
145
146        log::debug!("Handle receive messages dropped");
147
148        result
149    }
150
151    /// Send a request to an adapter and get a response back
152    /// Note: This function will block until a response is sent back from the adapter
153    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
154        let serialized_arguments = serde_json::to_value(arguments)?;
155
156        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
157
158        let sequence_id = self.next_sequence_id();
159
160        let request = crate::messages::Request {
161            seq: sequence_id,
162            command: R::COMMAND.to_string(),
163            arguments: Some(serialized_arguments),
164        };
165        self.transport_delegate
166            .add_pending_request(sequence_id, callback_tx)
167            .await;
168
169        log::debug!(
170            "Client {} send `{}` request with sequence_id: {}",
171            self.id.0,
172            R::COMMAND,
173            sequence_id
174        );
175
176        self.send_message(Message::Request(request)).await?;
177
178        let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
179        let command = R::COMMAND.to_string();
180
181        select! {
182            response = callback_rx.fuse() => {
183                log::debug!(
184                    "Client {} received response for: `{}` sequence_id: {}",
185                    self.id.0,
186                    command,
187                    sequence_id
188                );
189
190                let response = response??;
191                match response.success {
192                    true => Ok(serde_json::from_value(response.body.unwrap_or_default())?),
193                    false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
194                }
195            }
196
197            _ = timeout => {
198                self.transport_delegate.cancel_pending_request(&sequence_id).await;
199                log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
200                anyhow::bail!("DAP request timeout");
201            }
202        }
203    }
204
205    pub async fn send_message(&self, message: Message) -> Result<()> {
206        self.transport_delegate.send_message(message).await
207    }
208
209    pub fn id(&self) -> SessionId {
210        self.id
211    }
212
213    pub fn name(&self) -> DebugAdapterName {
214        self.name.clone()
215    }
216    pub fn binary(&self) -> &DebugAdapterBinary {
217        &self.binary
218    }
219
220    /// Get the next sequence id to be used in a request
221    pub fn next_sequence_id(&self) -> u64 {
222        self.sequence_count.fetch_add(1, Ordering::Relaxed)
223    }
224
225    pub async fn shutdown(&self) -> Result<()> {
226        self.transport_delegate.shutdown().await
227    }
228
229    pub fn has_adapter_logs(&self) -> bool {
230        self.transport_delegate.has_adapter_logs()
231    }
232
233    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
234    where
235        F: 'static + Send + FnMut(IoKind, &str),
236    {
237        self.transport_delegate.add_log_handler(f, kind);
238    }
239
240    #[cfg(any(test, feature = "test-support"))]
241    pub async fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
242    where
243        F: 'static
244            + Send
245            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
246    {
247        let transport = self.transport_delegate.transport().as_fake();
248        transport.on_request::<R, F>(handler).await;
249    }
250
251    #[cfg(any(test, feature = "test-support"))]
252    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
253        self.send_message(Message::Request(dap_types::messages::Request {
254            seq: self.sequence_count.load(Ordering::Relaxed),
255            command: R::COMMAND.into(),
256            arguments: serde_json::to_value(args).ok(),
257        }))
258        .await
259        .unwrap();
260    }
261
262    #[cfg(any(test, feature = "test-support"))]
263    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
264    where
265        F: 'static + Send + Fn(Response),
266    {
267        let transport = self.transport_delegate.transport().as_fake();
268        transport.on_response::<R, F>(handler).await;
269    }
270
271    #[cfg(any(test, feature = "test-support"))]
272    pub async fn fake_event(&self, event: dap_types::messages::Events) {
273        self.send_message(Message::Event(Box::new(event)))
274            .await
275            .unwrap();
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Initializes logging when `RUST_LOG` is set, then installs a test
    /// `SettingsStore` and registers `DebuggerSettings`.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::try_init().ok();
        }

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            DebuggerSettings::register(cx);
        });
    }

    /// Placeholder adapter binary shared by every test in this module.
    fn placeholder_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
        }
    }

    /// Capabilities the fake adapter returns from `initialize`.
    fn configured_capabilities() -> Capabilities {
        Capabilities {
            supports_configuration_done_request: Some(true),
            ..Default::default()
        }
    }

    /// An `initialize` request is answered with the registered handler's capabilities.
    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            placeholder_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client
            .on_request::<Initialize, _>(move |_, _| Ok(configured_capabilities()))
            .await;

        cx.run_until_parked();

        let arguments = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let response = client.request::<Initialize>(arguments).await.unwrap();

        cx.run_until_parked();

        assert_eq!(configured_capabilities(), response);

        client.shutdown().await.unwrap();
    }

    /// A fake event reaches the message handler unchanged.
    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));
        let handler_flag = called_event_handler.clone();

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            placeholder_binary(),
            Box::new(move |event| {
                handler_flag.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    event
                );
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    /// A fake reverse request reaches the message handler with its arguments serialized.
    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));
        let handler_flag = called_event_handler.clone();

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("test-adapter".into()),
            placeholder_binary(),
            Box::new(move |event| {
                handler_flag.store(true, Ordering::SeqCst);

                let expected = Message::Request(dap_types::messages::Request {
                    seq: 1,
                    command: RunInTerminal::COMMAND.into(),
                    arguments: Some(json!({
                        "cwd": "/project/path/src",
                        "args": ["node", "test.js"],
                    })),
                });
                assert_eq!(expected, event);
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let reverse_request_arguments = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(reverse_request_arguments)
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}
487}