client.rs

  1use crate::{
  2    adapters::DebugAdapterBinary,
  3    transport::{IoKind, LogKind, TransportDelegate},
  4};
  5use anyhow::Result;
  6use dap_types::{
  7    messages::{Message, Response},
  8    requests::Request,
  9};
 10use futures::channel::oneshot;
 11use gpui::{AppContext, AsyncApp};
 12use smol::channel::{Receiver, Sender};
 13use std::{
 14    hash::Hash,
 15    sync::atomic::{AtomicU64, Ordering},
 16};
 17
 18#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 19#[repr(transparent)]
 20pub struct SessionId(pub u32);
 21
 22impl SessionId {
 23    pub fn from_proto(client_id: u64) -> Self {
 24        Self(client_id as u32)
 25    }
 26
 27    pub fn to_proto(&self) -> u64 {
 28        self.0 as u64
 29    }
 30}
 31
/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
pub struct DebugAdapterClient {
    /// Session this client belongs to; included in the client's log messages.
    id: SessionId,
    /// Source of request sequence numbers: set to 1 in `start` and advanced by
    /// `next_sequence_id`.
    sequence_count: AtomicU64,
    /// Description of the adapter binary this client was started with.
    binary: DebugAdapterBinary,
    /// Transport wrapper used to send messages, register pending requests, attach log
    /// handlers and shut the connection down.
    transport_delegate: TransportDelegate,
}
 39
/// Callback that `DebugAdapterClient` invokes for each event and reverse request received
/// from the adapter. Boxed so that any `FnMut(Message)` that is `Send + Sync + 'static`
/// can be supplied.
pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
 41
 42impl DebugAdapterClient {
 43    pub async fn start(
 44        id: SessionId,
 45        binary: DebugAdapterBinary,
 46        message_handler: DapMessageHandler,
 47        cx: AsyncApp,
 48    ) -> Result<Self> {
 49        let ((server_rx, server_tx), transport_delegate) =
 50            TransportDelegate::start(&binary, cx.clone()).await?;
 51        let this = Self {
 52            id,
 53            binary,
 54            transport_delegate,
 55            sequence_count: AtomicU64::new(1),
 56        };
 57        log::info!("Successfully connected to debug adapter");
 58
 59        let client_id = this.id;
 60
 61        // start handling events/reverse requests
 62        cx.background_spawn(Self::handle_receive_messages(
 63            client_id,
 64            server_rx,
 65            server_tx.clone(),
 66            message_handler,
 67        ))
 68        .detach();
 69
 70        Ok(this)
 71    }
 72
 73    pub async fn reconnect(
 74        &self,
 75        session_id: SessionId,
 76        binary: DebugAdapterBinary,
 77        message_handler: DapMessageHandler,
 78        cx: AsyncApp,
 79    ) -> Result<Self> {
 80        let binary = match self.transport_delegate.transport() {
 81            crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
 82                command: binary.command,
 83                arguments: binary.arguments,
 84                envs: binary.envs,
 85                cwd: binary.cwd,
 86                connection: Some(crate::adapters::TcpArguments {
 87                    host: tcp_transport.host,
 88                    port: tcp_transport.port,
 89                    timeout: Some(tcp_transport.timeout),
 90                }),
 91                request_args: binary.request_args,
 92            },
 93            _ => self.binary.clone(),
 94        };
 95
 96        Self::start(session_id, binary, message_handler, cx).await
 97    }
 98
 99    async fn handle_receive_messages(
100        client_id: SessionId,
101        server_rx: Receiver<Message>,
102        client_tx: Sender<Message>,
103        mut message_handler: DapMessageHandler,
104    ) -> Result<()> {
105        let result = loop {
106            let message = match server_rx.recv().await {
107                Ok(message) => message,
108                Err(e) => break Err(e.into()),
109            };
110            match message {
111                Message::Event(ev) => {
112                    log::debug!("Client {} received event `{}`", client_id.0, &ev);
113
114                    message_handler(Message::Event(ev))
115                }
116                Message::Request(req) => {
117                    log::debug!(
118                        "Client {} received reverse request `{}`",
119                        client_id.0,
120                        &req.command
121                    );
122
123                    message_handler(Message::Request(req))
124                }
125                Message::Response(response) => {
126                    log::debug!("Received response after request timeout: {:#?}", response);
127                }
128            }
129
130            smol::future::yield_now().await;
131        };
132
133        drop(client_tx);
134
135        log::debug!("Handle receive messages dropped");
136
137        result
138    }
139
140    /// Send a request to an adapter and get a response back
141    /// Note: This function will block until a response is sent back from the adapter
142    pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
143        let serialized_arguments = serde_json::to_value(arguments)?;
144
145        let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
146
147        let sequence_id = self.next_sequence_id();
148
149        let request = crate::messages::Request {
150            seq: sequence_id,
151            command: R::COMMAND.to_string(),
152            arguments: Some(serialized_arguments),
153        };
154        self.transport_delegate
155            .add_pending_request(sequence_id, callback_tx)
156            .await;
157
158        log::debug!(
159            "Client {} send `{}` request with sequence_id: {}",
160            self.id.0,
161            R::COMMAND,
162            sequence_id
163        );
164
165        self.send_message(Message::Request(request)).await?;
166
167        let command = R::COMMAND.to_string();
168
169        let response = callback_rx.await??;
170        log::debug!(
171            "Client {} received response for: `{}` sequence_id: {}",
172            self.id.0,
173            command,
174            sequence_id
175        );
176        match response.success {
177            true => {
178                if let Some(json) = response.body {
179                    Ok(serde_json::from_value(json)?)
180                // Note: dap types configure themselves to return `None` when an empty object is received,
181                // which then fails here...
182                } else if let Ok(result) =
183                    serde_json::from_value(serde_json::Value::Object(Default::default()))
184                {
185                    Ok(result)
186                } else {
187                    Ok(serde_json::from_value(Default::default())?)
188                }
189            }
190            false => anyhow::bail!("Request failed: {}", response.message.unwrap_or_default()),
191        }
192    }
193
194    pub async fn send_message(&self, message: Message) -> Result<()> {
195        self.transport_delegate.send_message(message).await
196    }
197
198    pub fn id(&self) -> SessionId {
199        self.id
200    }
201
202    pub fn binary(&self) -> &DebugAdapterBinary {
203        &self.binary
204    }
205
206    /// Get the next sequence id to be used in a request
207    pub fn next_sequence_id(&self) -> u64 {
208        self.sequence_count.fetch_add(1, Ordering::Relaxed)
209    }
210
211    pub async fn shutdown(&self) -> Result<()> {
212        self.transport_delegate.shutdown().await
213    }
214
215    pub fn has_adapter_logs(&self) -> bool {
216        self.transport_delegate.has_adapter_logs()
217    }
218
219    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
220    where
221        F: 'static + Send + FnMut(IoKind, &str),
222    {
223        self.transport_delegate.add_log_handler(f, kind);
224    }
225
226    #[cfg(any(test, feature = "test-support"))]
227    pub fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
228    where
229        F: 'static
230            + Send
231            + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
232    {
233        let transport = self.transport_delegate.transport().as_fake();
234        transport.on_request::<R, F>(handler);
235    }
236
237    #[cfg(any(test, feature = "test-support"))]
238    pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
239        self.send_message(Message::Request(dap_types::messages::Request {
240            seq: self.sequence_count.load(Ordering::Relaxed),
241            command: R::COMMAND.into(),
242            arguments: serde_json::to_value(args).ok(),
243        }))
244        .await
245        .unwrap();
246    }
247
248    #[cfg(any(test, feature = "test-support"))]
249    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
250    where
251        F: 'static + Send + Fn(Response),
252    {
253        let transport = self.transport_delegate.transport().as_fake();
254        transport.on_response::<R, F>(handler).await;
255    }
256
257    #[cfg(any(test, feature = "test-support"))]
258    pub async fn fake_event(&self, event: dap_types::messages::Events) {
259        self.send_message(Message::Event(Box::new(event)))
260            .await
261            .unwrap();
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
    use dap_types::{
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments, StartDebuggingRequestArguments,
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    /// Sets up logging and registers the settings globals the client tests rely on.
    pub fn init_test(cx: &mut TestAppContext) {
        zlog::init_test();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            DebuggerSettings::register(cx);
        });
    }

    /// Placeholder launch binary shared by every test in this module.
    fn test_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
            request_args: StartDebuggingRequestArguments {
                configuration: serde_json::Value::Null,
                request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
            },
        }
    }

    /// An `Initialize` request is answered by the registered request handler and the
    /// response is deserialized into `Capabilities`.
    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client.on_request::<Initialize, _>(move |_, _| {
            Ok(Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            })
        });

        cx.run_until_parked();

        let initialize_args = InitializeRequestArguments {
            client_id: Some("zed".to_owned()),
            client_name: Some("Zed".to_owned()),
            adapter_id: "fake-adapter".to_owned(),
            locale: Some("en-US".to_owned()),
            path_format: Some(InitializeRequestArgumentsPathFormat::Path),
            supports_variable_type: Some(true),
            supports_variable_paging: Some(false),
            supports_run_in_terminal_request: Some(true),
            supports_memory_references: Some(true),
            supports_progress_reporting: Some(false),
            supports_invalidated_event: Some(false),
            lines_start_at1: Some(true),
            columns_start_at1: Some(true),
            supports_memory_event: Some(false),
            supports_args_can_be_interpreted_by_shell: Some(false),
            supports_start_debugging_request: Some(true),
            supports_ansistyling: Some(false),
        };
        let response = client
            .request::<Initialize>(initialize_args)
            .await
            .unwrap();

        cx.run_until_parked();

        assert_eq!(
            Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            },
            response
        );

        client.shutdown().await.unwrap();
    }

    /// An event sent through the client reaches the message handler unchanged.
    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new({
                let handler_called = handler_called.clone();
                move |message| {
                    handler_called.store(true, Ordering::SeqCst);

                    assert_eq!(
                        Message::Event(Box::new(Events::Initialized(
                            Some(Capabilities::default())
                        ))),
                        message
                    );
                }
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    /// A reverse request sent through the client reaches the message handler with the
    /// expected sequence number, command and arguments.
    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let handler_called = Arc::new(AtomicBool::new(false));

        let client = DebugAdapterClient::start(
            SessionId(1),
            test_binary(),
            Box::new({
                let handler_called = handler_called.clone();
                move |message| {
                    handler_called.store(true, Ordering::SeqCst);

                    assert_eq!(
                        Message::Request(dap_types::messages::Request {
                            seq: 1,
                            command: RunInTerminal::COMMAND.into(),
                            arguments: Some(json!({
                                "cwd": "/project/path/src",
                                "args": ["node", "test.js"],
                            }))
                        }),
                        message
                    );
                }
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        let terminal_args = RunInTerminalRequestArguments {
            kind: None,
            title: None,
            cwd: "/project/path/src".into(),
            args: vec!["node".into(), "test.js".into()],
            env: None,
            args_can_be_interpreted_by_shell: None,
        };
        client
            .fake_reverse_request::<RunInTerminal>(terminal_args)
            .await;

        cx.run_until_parked();

        assert!(
            handler_called.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}