1use crate::{
2 adapters::{DebugAdapterBinary, DebugAdapterName},
3 transport::{IoKind, LogKind, TransportDelegate},
4};
5use anyhow::{anyhow, Result};
6use dap_types::{
7 messages::{Message, Response},
8 requests::Request,
9};
10use futures::{channel::oneshot, select, FutureExt as _};
11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
12use smol::channel::{Receiver, Sender};
13use std::{
14 hash::Hash,
15 sync::atomic::{AtomicU64, Ordering},
16 time::Duration,
17};
18
/// Upper bound on how long `DebugAdapterClient::request` waits for a response.
///
/// Test builds use a much shorter bound, presumably so that a hung fake adapter
/// fails a test quickly instead of stalling it.
#[cfg(any(test, feature = "test-support"))]
const DAP_REQUEST_TIMEOUT: Duration = Duration::from_millis(2_000);

/// Upper bound on how long `DebugAdapterClient::request` waits for a response.
#[cfg(not(any(test, feature = "test-support")))]
const DAP_REQUEST_TIMEOUT: Duration = Duration::from_millis(12_000);
24
/// Identifier of a debug session.
///
/// `#[repr(transparent)]` gives it exactly the layout of the wrapped `u32`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SessionId(pub u32);

impl SessionId {
    /// Builds a `SessionId` from its protobuf (`u64`) representation.
    ///
    /// Only the low 32 bits are kept: values above `u32::MAX` are truncated.
    pub fn from_proto(client_id: u64) -> Self {
        // `as` truncates to the low 32 bits; the wire type is wider than the id.
        SessionId(client_id as u32)
    }

    /// Converts this id into its protobuf (`u64`) representation (lossless widening).
    pub fn to_proto(&self) -> u64 {
        u64::from(self.0)
    }
}
38
/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
pub struct DebugAdapterClient {
    /// Session this client belongs to; used in log messages.
    id: SessionId,
    /// Adapter name; returned by `name()` and carried over by `reconnect`.
    name: DebugAdapterName,
    /// Source of request sequence numbers. `start` initializes it to 1 and
    /// `next_sequence_id` hands out values with `fetch_add`.
    sequence_count: AtomicU64,
    /// Binary this client was started with; `reconnect` reuses it for non-TCP transports.
    binary: DebugAdapterBinary,
    /// Background executor; `request` uses it to create the response timeout timer.
    executor: BackgroundExecutor,
    /// Transport wrapper through which messages, pending requests, log handlers
    /// and shutdown are routed.
    transport_delegate: TransportDelegate,
}
48
49pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
50
51impl DebugAdapterClient {
52 pub async fn start(
53 id: SessionId,
54 name: DebugAdapterName,
55 binary: DebugAdapterBinary,
56 message_handler: DapMessageHandler,
57 cx: AsyncApp,
58 ) -> Result<Self> {
59 let ((server_rx, server_tx), transport_delegate) =
60 TransportDelegate::start(&binary, cx.clone()).await?;
61 let this = Self {
62 id,
63 name,
64 binary,
65 transport_delegate,
66 sequence_count: AtomicU64::new(1),
67 executor: cx.background_executor().clone(),
68 };
69 log::info!("Successfully connected to debug adapter");
70
71 let client_id = this.id;
72
73 // start handling events/reverse requests
74
75 cx.background_spawn(Self::handle_receive_messages(
76 client_id,
77 server_rx,
78 server_tx.clone(),
79 message_handler,
80 ))
81 .detach();
82
83 Ok(this)
84 }
85
86 pub async fn reconnect(
87 &self,
88 session_id: SessionId,
89 binary: DebugAdapterBinary,
90 message_handler: DapMessageHandler,
91 cx: AsyncApp,
92 ) -> Result<Self> {
93 let binary = match self.transport_delegate.transport() {
94 crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
95 command: binary.command,
96 arguments: binary.arguments,
97 envs: binary.envs,
98 cwd: binary.cwd,
99 connection: Some(crate::adapters::TcpArguments {
100 host: tcp_transport.host,
101 port: tcp_transport.port,
102 timeout: Some(tcp_transport.timeout),
103 }),
104 },
105 _ => self.binary.clone(),
106 };
107
108 Self::start(session_id, self.name(), binary, message_handler, cx).await
109 }
110
111 async fn handle_receive_messages(
112 client_id: SessionId,
113 server_rx: Receiver<Message>,
114 client_tx: Sender<Message>,
115 mut message_handler: DapMessageHandler,
116 ) -> Result<()> {
117 let result = loop {
118 let message = match server_rx.recv().await {
119 Ok(message) => message,
120 Err(e) => break Err(e.into()),
121 };
122
123 match message {
124 Message::Event(ev) => {
125 log::debug!("Client {} received event `{}`", client_id.0, &ev);
126
127 message_handler(Message::Event(ev))
128 }
129 Message::Request(req) => {
130 log::debug!(
131 "Client {} received reverse request `{}`",
132 client_id.0,
133 &req.command
134 );
135
136 message_handler(Message::Request(req))
137 }
138 Message::Response(response) => {
139 log::debug!("Received response after request timeout: {:#?}", response);
140 }
141 }
142
143 smol::future::yield_now().await;
144 };
145
146 drop(client_tx);
147
148 log::debug!("Handle receive messages dropped");
149
150 result
151 }
152
153 /// Send a request to an adapter and get a response back
154 /// Note: This function will block until a response is sent back from the adapter
155 pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
156 let serialized_arguments = serde_json::to_value(arguments)?;
157
158 let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
159
160 let sequence_id = self.next_sequence_id();
161
162 let request = crate::messages::Request {
163 seq: sequence_id,
164 command: R::COMMAND.to_string(),
165 arguments: Some(serialized_arguments),
166 };
167
168 self.transport_delegate
169 .add_pending_request(sequence_id, callback_tx)
170 .await;
171
172 log::debug!(
173 "Client {} send `{}` request with sequence_id: {}",
174 self.id.0,
175 R::COMMAND.to_string(),
176 sequence_id
177 );
178
179 self.send_message(Message::Request(request)).await?;
180
181 let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
182 let command = R::COMMAND.to_string();
183
184 select! {
185 response = callback_rx.fuse() => {
186 log::debug!(
187 "Client {} received response for: `{}` sequence_id: {}",
188 self.id.0,
189 command,
190 sequence_id
191 );
192
193 let response = response??;
194 match response.success {
195 true => Ok(serde_json::from_value(response.body.unwrap_or_default())?),
196 false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
197 }
198 }
199
200 _ = timeout => {
201 self.transport_delegate.cancel_pending_request(&sequence_id).await;
202 log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
203 anyhow::bail!("DAP request timeout");
204 }
205 }
206 }
207
208 pub async fn send_message(&self, message: Message) -> Result<()> {
209 self.transport_delegate.send_message(message).await
210 }
211
212 pub fn id(&self) -> SessionId {
213 self.id
214 }
215
216 pub fn name(&self) -> DebugAdapterName {
217 self.name.clone()
218 }
219 pub fn binary(&self) -> &DebugAdapterBinary {
220 &self.binary
221 }
222
223 /// Get the next sequence id to be used in a request
224 pub fn next_sequence_id(&self) -> u64 {
225 self.sequence_count.fetch_add(1, Ordering::Relaxed)
226 }
227
228 pub async fn shutdown(&self) -> Result<()> {
229 self.transport_delegate.shutdown().await
230 }
231
232 pub fn has_adapter_logs(&self) -> bool {
233 self.transport_delegate.has_adapter_logs()
234 }
235
236 pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
237 where
238 F: 'static + Send + FnMut(IoKind, &str),
239 {
240 self.transport_delegate.add_log_handler(f, kind);
241 }
242
243 #[cfg(any(test, feature = "test-support"))]
244 pub async fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
245 where
246 F: 'static
247 + Send
248 + FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>,
249 {
250 let transport = self.transport_delegate.transport().as_fake();
251 transport.on_request::<R, F>(handler).await;
252 }
253
254 #[cfg(any(test, feature = "test-support"))]
255 pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
256 self.send_message(Message::Request(dap_types::messages::Request {
257 seq: self.sequence_count.load(Ordering::Relaxed),
258 command: R::COMMAND.into(),
259 arguments: serde_json::to_value(args).ok(),
260 }))
261 .await
262 .unwrap();
263 }
264
265 #[cfg(any(test, feature = "test-support"))]
266 pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
267 where
268 F: 'static + Send + Fn(Response),
269 {
270 let transport = self.transport_delegate.transport().as_fake();
271 transport.on_response::<R, F>(handler).await;
272 }
273
274 #[cfg(any(test, feature = "test-support"))]
275 pub async fn fake_event(&self, event: dap_types::messages::Events) {
276 self.send_message(Message::Event(Box::new(event)))
277 .await
278 .unwrap();
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::debugger_settings::DebuggerSettings;
    use dap_types::{
        messages::Events,
        requests::{Initialize, Request, RunInTerminal},
        Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
        RunInTerminalRequestArguments,
    };
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    /// Installs a test `SettingsStore`, registers `DebuggerSettings`, and turns on
    /// `env_logger` when `RUST_LOG` is set.
    pub fn init_test(cx: &mut gpui::TestAppContext) {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::try_init().ok();
        }

        cx.update(|cx| {
            let settings = SettingsStore::test(cx);
            cx.set_global(settings);
            DebuggerSettings::register(cx);
        });
    }

    /// Minimal adapter binary shared by every test in this module.
    fn test_binary() -> DebugAdapterBinary {
        DebugAdapterBinary {
            command: "command".into(),
            arguments: Default::default(),
            envs: Default::default(),
            connection: None,
            cwd: None,
        }
    }

    #[gpui::test]
    pub async fn test_initialize_client(cx: &mut TestAppContext) {
        init_test(cx);

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            test_binary(),
            Box::new(|_| panic!("Did not expect to hit this code path")),
            cx.to_async(),
        )
        .await
        .unwrap();

        client
            .on_request::<Initialize, _>(move |_, _| {
                Ok(dap_types::Capabilities {
                    supports_configuration_done_request: Some(true),
                    ..Default::default()
                })
            })
            .await;

        cx.run_until_parked();

        let response = client
            .request::<Initialize>(InitializeRequestArguments {
                client_id: Some("zed".to_owned()),
                client_name: Some("Zed".to_owned()),
                adapter_id: "fake-adapter".to_owned(),
                locale: Some("en-US".to_owned()),
                path_format: Some(InitializeRequestArgumentsPathFormat::Path),
                supports_variable_type: Some(true),
                supports_variable_paging: Some(false),
                supports_run_in_terminal_request: Some(true),
                supports_memory_references: Some(true),
                supports_progress_reporting: Some(false),
                supports_invalidated_event: Some(false),
                lines_start_at1: Some(true),
                columns_start_at1: Some(true),
                supports_memory_event: Some(false),
                supports_args_can_be_interpreted_by_shell: Some(false),
                supports_start_debugging_request: Some(true),
                supports_ansistyling: Some(false),
            })
            .await
            .unwrap();

        cx.run_until_parked();

        assert_eq!(
            dap_types::Capabilities {
                supports_configuration_done_request: Some(true),
                ..Default::default()
            },
            response
        );

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName("adapter".into()),
            test_binary(),
            Box::new({
                let called_event_handler = called_event_handler.clone();
                move |event| {
                    called_event_handler.store(true, Ordering::SeqCst);

                    assert_eq!(
                        Message::Event(Box::new(Events::Initialized(
                            Some(Capabilities::default())
                        ))),
                        event
                    );
                }
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_event(Events::Initialized(Some(Capabilities::default())))
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }

    #[gpui::test]
    pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
        init_test(cx);

        let called_event_handler = Arc::new(AtomicBool::new(false));

        let client = DebugAdapterClient::start(
            SessionId(1),
            DebugAdapterName(Arc::from("test-adapter")),
            test_binary(),
            Box::new({
                let called_event_handler = called_event_handler.clone();
                move |event| {
                    called_event_handler.store(true, Ordering::SeqCst);

                    assert_eq!(
                        Message::Request(dap_types::messages::Request {
                            seq: 1,
                            command: RunInTerminal::COMMAND.into(),
                            arguments: Some(json!({
                                "cwd": "/project/path/src",
                                "args": ["node", "test.js"],
                            }))
                        }),
                        event
                    );
                }
            }),
            cx.to_async(),
        )
        .await
        .unwrap();

        cx.run_until_parked();

        client
            .fake_reverse_request::<RunInTerminal>(RunInTerminalRequestArguments {
                kind: None,
                title: None,
                cwd: "/project/path/src".into(),
                args: vec!["node".into(), "test.js".into()],
                env: None,
                args_can_be_interpreted_by_shell: None,
            })
            .await;

        cx.run_until_parked();

        assert!(
            called_event_handler.load(Ordering::SeqCst),
            "Event handler was not called"
        );

        client.shutdown().await.unwrap();
    }
}