client.rs

  1use crate::{
  2    adapters::DebugAdapterBinary,
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{Result, anyhow};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::channel::oneshot;
 11use gpui::{AppContext, AsyncApp};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16};
 17
 18#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 19#[repr(transparent)]
 20pub struct SessionId(pub u32);
 21
 22impl SessionId {
 23    pub fn from_proto(client_id: u64) -> Self {
 24        Self(client_id as u32)
 25    }
 26
 27    pub fn to_proto(&self) -> u64 {
 28        self.0 as u64
 29    }
 30}
 31
 32/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
 33pub struct DebugAdapterClient {
 34    id: SessionId,
 35    sequence_count: AtomicU64,
 36    binary: DebugAdapterBinary,
 37    transport_delegate: TransportDelegate,
 38}
 39
 40pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 41
 42impl DebugAdapterClient {
 43    pub async fn start(
 44        id: SessionId,
 45        binary: DebugAdapterBinary,
 46        message_handler: DapMessageHandler,
 47        cx: AsyncApp,
 48    ) -> Result<Self> {
 49        let ((server_rx, server_tx), transport_delegate) =
 50            TransportDelegate::start(&binary, cx.clone()).await?;
 51        let this = Self {
 52            id,
 53            binary,
 54            transport_delegate,
 55            sequence_count: AtomicU64::new(1),
 56        };
 57        log::info!("Successfully connected to debug adapter");
 58
 59        let client_id = this.id;
 60
 61        // start handling events/reverse requests
 62        cx.background_spawn(Self::handle_receive_messages(
 63            client_id,
 64            server_rx,
 65            server_tx.clone(),
 66            message_handler,
 67        ))
 68        .detach();
 69
 70        Ok(this)
 71    }
 72
 73    pub async fn reconnect(
 74        &self,
 75        session_id: SessionId,
 76        binary: DebugAdapterBinary,
 77        message_handler: DapMessageHandler,
 78        cx: AsyncApp,
 79    ) -> Result<Self> {
 80        let binary = match self.transport_delegate.transport() {
 81            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 82                command: binary.command,
 83                arguments: binary.arguments,
 84                envs: binary.envs,
 85                cwd: binary.cwd,
 86                connection: Some(crate::adapters::TcpArguments {
 87                    host: tcp_transport.host,
 88                    port: tcp_transport.port,
 89                    timeout: Some(tcp_transport.timeout),
 90                }),
 91                request_args: binary.request_args,
 92            },
 93            _ => self.binary.clone(),
 94        };
 95
 96        Self::start(session_id, binary, message_handler, cx).await
 97    }
 98
 99    async fn handle_receive_messages(
100        client_id: SessionId,
101        server_rx: Receiver<Message>,
102        client_tx: Sender<Message>,
103        mut message_handler: DapMessageHandler,
104    ) -> Result<()> {
105        let result = loop {
106            let message = match server_rx.recv().await {
107                Ok(message) => message,
108                Err(e) => break Err(e.into()),
109            };
110            match message {
111                Message::Event(ev) => {
112                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
113
114                    message_handler(Message::Event(ev))
115                }
116                Message::Request(req) => {
117                    log::debug!(
118                        "Client {} received reverse request `{}`",
119                        client_id.0,
120                        &req.command
121                    );
122
123                    message_handler(Message::Request(req))
124                }
125                Message::Response(response) => {
126                    log::debug!("Received response after request timeout: {:#?}", response);
127                }
128            }
129
130            smol::future::yield_now().await;
131        };
132
133        drop(client_tx);
134
135        log::debug!("Handle receive messages dropped");
136
137        result
138    }
139
140    /// Send a request to an adapter and get a response back
141    /// Note: This function will block until a response is sent back from the adapter
142    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
143        let serialized_arguments = serde_json::to_value(arguments)?;
144
145        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
146
147        let sequence_id = self.next_sequence_id();
148
149        let request = crate::messages::Request {
150            seq: sequence_id,
151            command: R::COMMAND.to_string(),
152            arguments: Some(serialized_arguments),
153        };
154        self.transport_delegate
155            .add_pending_request(sequence_id, callback_tx)
156            .await;
157
158        log::debug!(
159            "Client {} send `{}` request with sequence_id: {}",
160            self.id.0,
161            R::COMMAND,
162            sequence_id
163        );
164
165        self.send_message(Message::Request(request)).await?;
166
167        let command = R::COMMAND.to_string();
168
169        let response = callback_rx.await??;
170        log::debug!(
171            "Client {} received response for: `{}` sequence_id: {}",
172            self.id.0,
173            command,
174            sequence_id
175        );
176        match response.success {
177            true => {
178                if let Some(json) = response.body {
179                    Ok(serde_json::from_value(json)?)
180                // Note: dap types configure themselves to return `None` when an empty object is received,
181                // which then fails here...
182                } else if let Ok(result) =
183                    serde_json::from_value(serde_json::Value::Object(Default::default()))
184                {
185                    Ok(result)
186                } else {
187                    Ok(serde_json::from_value(Default::default())?)
188                }
189            }
190            false => Err(anyhow!(
191                "Request failed: {}",
192                response.message.unwrap_or_default()
193            )),
194        }
195    }
196
197    pub async fn send_message(&self, message: Message) -> Result<()> {
198        self.transport_delegate.send_message(message).await
199    }
200
201    pub fn id(&self) -> SessionId {
202        self.id
203    }
204
205    pub fn binary(&self) -> &DebugAdapterBinary {
206        &self.binary
207    }
208
209    /// Get the next sequence id to be used in a request
210    pub fn next_sequence_id(&self) -> u64 {
211        self.sequence_count.fetch_add(1, Ordering::Relaxed)
212    }
213
214    pub async fn shutdown(&self) -> Result<()> {
215        self.transport_delegate.shutdown().await
216    }
217
218    pub fn has_adapter_logs(&self) -> bool {
219        self.transport_delegate.has_adapter_logs()
220    }
221
222    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
223    where
224        F: 'static + Send + FnMut(IoKind, &str),
225    {
226        self.transport_delegate.add_log_handler(f, kind);
227    }
228
229    #[cfg(any(test, feature = "test-support"))]
230    pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
231    where
232        F: 'static
233            + Send
234            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
235    {
236        let transport = self.transport_delegate.transport().as_fake();
237        transport.on_request::<R, F>(handler);
238    }
239
240    #[cfg(any(test, feature = "test-support"))]
241    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
242        self.send_message(Message::Request(dap_types::messages::Request {
243            seq: self.sequence_count.load(Ordering::Relaxed),
244            command: R::COMMAND.into(),
245            arguments: serde_json::to_value(args).ok(),
246        }))
247        .await
248        .unwrap();
249    }
250
251    #[cfg(any(test, feature = "test-support"))]
252    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
253    where
254        F: 'static + Send + Fn(Response),
255    {
256        let transport = self.transport_delegate.transport().as_fake();
257        transport.on_response::<R, F>(handler).await;
258    }
259
260    #[cfg(any(test, feature = "test-support"))]
261    pub async fn fake_event(&self, event: dap_types::messages::Events) {
262        self.send_message(Message::Event(Box::new(event)))
263            .await
264            .unwrap();
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments, StartDebuggingRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Installs a test settings store and registers the debugger settings.
    /// The logger is only initialized when `RUST_LOG` is set.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::try_init().ok();
        }

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            DebuggerSettings::register(cx);
        });
    }

    /// Adapter description shared by every test: a launch request with no `connection`.
    fn fake_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
            request_args: StartDebuggingRequestArguments {
                configuration: serde_json::Value::Null,
                request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
            },
        }
    }

    /// Capabilities the fake adapter answers `initialize` with.
    fn expected_capabilities() -> Capabilities {
        Capabilities {
            supports_configuration_done_request: Some(true),
            ..Default::default()
        }
    }

    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            fake_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client.on_request::<Initialize, _>(move |_, _| Ok(expected_capabilities()));

        cx.run_until_parked();

        let initialize_args = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let capabilities = client
            .request::<Initialize>(initialize_args)
            .await
            .unwrap();

        cx.run_until_parked();

        assert_eq!(expected_capabilities(), capabilities);

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler_flag = Arc::clone(&handler_called);

        let client = DebugAdapterClient::start(
            SessionId(1),
            fake_binary(),
            Box::new(move |message| {
                handler_flag.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    message
                );
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler_flag = Arc::clone(&handler_called);

        let client = DebugAdapterClient::start(
            SessionId(1),
            fake_binary(),
            Box::new(move |message| {
                handler_flag.store(true, Ordering::SeqCst);

                let expected = Message::Request(dap_types::messages::Request {
                    seq: 1,
                    command: RunInTerminal::COMMAND.into(),
                    arguments: Some(json!({
                        "cwd": "/project/path/src",
                        "args": ["node", "test.js"],
                    })),
                });
                assert_eq!(expected, message);
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let run_in_terminal_args = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(run_in_terminal_args)
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}