client.rs

  1use crate::{
  2    adapters::DebugAdapterBinary,
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::{Context as _, Result};
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::channel::oneshot;
 11use gpui::AsyncApp;
 12use std::{
 13    hash::Hash,
 14    sync::atomic::{AtomicU64, Ordering},
 15};
 16
 17#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 18#[repr(transparent)]
 19pub struct SessionId(pub u32);
 20
 21impl SessionId {
 22    pub fn from_proto(client_id: u64) -> Self {
 23        Self(client_id as u32)
 24    }
 25
 26    pub fn to_proto(&self) -> u64 {
 27        self.0 as u64
 28    }
 29}
 30
 31/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
 32pub struct DebugAdapterClient {
 33    id: SessionId,
 34    sequence_count: AtomicU64,
 35    binary: DebugAdapterBinary,
 36    transport_delegate: TransportDelegate,
 37}
 38
 39pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 40
 41impl DebugAdapterClient {
 42    pub async fn start(
 43        id: SessionId,
 44        binary: DebugAdapterBinary,
 45        message_handler: DapMessageHandler,
 46        cx: &mut AsyncApp,
 47    ) -> Result<Self> {
 48        let transport_delegate = TransportDelegate::start(&binary, cx).await?;
 49        let this = Self {
 50            id,
 51            binary,
 52            transport_delegate,
 53            sequence_count: AtomicU64::new(1),
 54        };
 55        this.connect(message_handler, cx).await?;
 56
 57        Ok(this)
 58    }
 59
 60    pub fn should_reconnect_for_ssh(&self) -> bool {
 61        self.transport_delegate.tcp_arguments().is_some()
 62            && self.binary.command.as_deref() == Some("ssh")
 63    }
 64
 65    pub async fn connect(
 66        &self,
 67        message_handler: DapMessageHandler,
 68        cx: &mut AsyncApp,
 69    ) -> Result<()> {
 70        self.transport_delegate.connect(message_handler, cx).await
 71    }
 72
 73    pub async fn create_child_connection(
 74        &self,
 75        session_id: SessionId,
 76        binary: DebugAdapterBinary,
 77        message_handler: DapMessageHandler,
 78        cx: &mut AsyncApp,
 79    ) -> Result<Self> {
 80        let binary = if let Some(connection) = self.transport_delegate.tcp_arguments() {
 81            DebugAdapterBinary {
 82                command: None,
 83                arguments: Default::default(),
 84                envs: Default::default(),
 85                cwd: Default::default(),
 86                connection: Some(connection),
 87                request_args: binary.request_args,
 88            }
 89        } else {
 90            self.binary.clone()
 91        };
 92
 93        Self::start(session_id, binary, message_handler, cx).await
 94    }
 95
 96    /// Send a request to an adapter and get a response back
 97    /// Note: This function will block until a response is sent back from the adapter
 98    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
 99        let serialized_arguments = serde_json::to_value(arguments)?;
100
101        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
102
103        let sequence_id = self.next_sequence_id();
104
105        let request = crate::messages::Request {
106            seq: sequence_id,
107            command: R::COMMAND.to_string(),
108            arguments: Some(serialized_arguments),
109        };
110        self.transport_delegate
111            .pending_requests
112            .lock()
113            .as_mut()
114            .context("client is closed")?
115            .insert(sequence_id, callback_tx);
116
117        log::debug!(
118            "Client {} send `{}` request with sequence_id: {}",
119            self.id.0,
120            R::COMMAND,
121            sequence_id
122        );
123
124        self.send_message(Message::Request(request)).await?;
125
126        let command = R::COMMAND.to_string();
127
128        let response = callback_rx.await??;
129        log::debug!(
130            "Client {} received response for: `{}` sequence_id: {}",
131            self.id.0,
132            command,
133            sequence_id
134        );
135        match response.success {
136            true => {
137                if let Some(json) = response.body {
138                    Ok(serde_json::from_value(json)?)
139                // Note: dap types configure themselves to return `None` when an empty object is received,
140                // which then fails here...
141                } else if let Ok(result) =
142                    serde_json::from_value(serde_json::Value::Object(Default::default()))
143                {
144                    Ok(result)
145                } else {
146                    Ok(serde_json::from_value(Default::default())?)
147                }
148            }
149            false => anyhow::bail!("Request failed: {}", response.message.unwrap_or_default()),
150        }
151    }
152
153    pub async fn send_message(&self, message: Message) -> Result<()> {
154        self.transport_delegate.send_message(message).await
155    }
156
157    pub fn id(&self) -> SessionId {
158        self.id
159    }
160
161    pub fn binary(&self) -> &DebugAdapterBinary {
162        &self.binary
163    }
164
165    /// Get the next sequence id to be used in a request
166    pub fn next_sequence_id(&self) -> u64 {
167        self.sequence_count.fetch_add(1, Ordering::Relaxed)
168    }
169
170    pub fn kill(&self) {
171        log::debug!("Killing DAP process");
172        self.transport_delegate.transport.lock().kill();
173    }
174
175    pub fn has_adapter_logs(&self) -> bool {
176        self.transport_delegate.has_adapter_logs()
177    }
178
179    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
180    where
181        F: 'static + Send + FnMut(IoKind, Option<&str>, &str),
182    {
183        self.transport_delegate.add_log_handler(f, kind);
184    }
185
186    #[cfg(any(test, feature = "test-support"))]
187    pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
188    where
189        F: 'static
190            + Send
191            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
192    {
193        self.transport_delegate
194            .transport
195            .lock()
196            .as_fake()
197            .on_request::<R, F>(handler);
198    }
199
200    #[cfg(any(test, feature = "test-support"))]
201    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
202        self.send_message(Message::Request(dap_types::messages::Request {
203            seq: self.sequence_count.load(Ordering::Relaxed),
204            command: R::COMMAND.into(),
205            arguments: serde_json::to_value(args).ok(),
206        }))
207        .await
208        .unwrap();
209    }
210
211    #[cfg(any(test, feature = "test-support"))]
212    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
213    where
214        F: 'static + Send + Fn(Response),
215    {
216        self.transport_delegate
217            .transport
218            .lock()
219            .as_fake()
220            .on_response::<R, F>(handler);
221    }
222
223    #[cfg(any(test, feature = "test-support"))]
224    pub async fn fake_event(&self, event: dap_types::messages::Events) {
225        self.send_message(Message::Event(Box::new(event)))
226            .await
227            .unwrap();
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debugger_settings::DebuggerSettings;
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments, StartDebuggingRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Installs test logging, a test settings store and the debugger settings.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        zlog::init_test();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            DebuggerSettings::register(cx);
        });
    }

    /// Binary description shared by every test in this module.
    fn test_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: Some("command".into()),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
            request_args: StartDebuggingRequestArguments {
                configuration: serde_json::Value::Null,
                request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
            },
        }
    }

    /// Starts a client for session 1 from [`test_binary`], panicking on failure.
    async fn start_client(
        cx: &mut TestAppContext,
        message_handler: DapMessageHandler,
    ) -> DebugAdapterClient {
        DebugAdapterClient::start(SessionId(1), test_binary(), message_handler, &mut cx.to_async())
            .await
            .unwrap()
    }

    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = start_client(
            cx,
            Box::new(|_| panic!("Did not expect to hit this code path")),
        )
        .await;

        client.on_request::<Initialize, _>(move |_, _| {
            Ok(Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            })
        });

        cx.run_until_parked();

        let initialize_args = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let capabilities = client
            .request::<Initialize>(initialize_args)
            .await
            .unwrap();

        cx.run_until_parked();

        let expected = Capabilities {
            supports_configuration_done_request: Some(true),
            ..Default::default()
        };
        assert_eq!(expected, capabilities);
    }

    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler: DapMessageHandler = Box::new({
            let handler_called = Arc::clone(&handler_called);
            move |message| {
                handler_called.store(true, Ordering::SeqCst);

                assert_eq!(
                    Message::Event(Box::new(Events::Initialized(Some(
                        Capabilities::default()
                    )))),
                    message
                );
            }
        });

        let client = start_client(cx, handler).await;

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );
    }

    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));
        let handler: DapMessageHandler = Box::new({
            let handler_called = Arc::clone(&handler_called);
            move |message| {
                handler_called.store(true, Ordering::SeqCst);

                let expected = Message::Request(dap_types::messages::Request {
                    seq: 1,
                    command: RunInTerminal::COMMAND.into(),
                    arguments: Some(json!({
                        "cwd": "/project/path/src",
                        "args": ["node", "test.js"],
                    })),
                });
                assert_eq!(expected, message);
            }
        });

        let client = start_client(cx, handler).await;

        cx.run_until_parked();

        let reverse_args = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(reverse_args)
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );
    }
}