1use crate::{
2 adapters::{DebugAdapterBinary, DebugAdapterName},
3 transport::{IoKind, LogKind, TransportDelegate},
4};
5use anyhow::{Result, anyhow};
6use dap_types::{
7 messages::{Message, Response},
8 requests::Request,
9};
10use futures::{FutureExt as _, channel::oneshot, select};
11use gpui::{AppContext, AsyncApp, BackgroundExecutor};
12use smol::channel::{Receiver, Sender};
13use std::{
14 hash::Hash,
15 sync::atomic::{AtomicU64, Ordering},
16 time::Duration,
17};
18
/// Upper bound on how long a single DAP request waits for the adapter's response.
///
/// Builds compiled with `test` or the `test-support` feature use 2 seconds;
/// every other build uses 12 seconds.
const DAP_REQUEST_TIMEOUT: Duration =
    Duration::from_secs(if cfg!(any(test, feature = "test-support")) {
        2
    } else {
        12
    });
24
/// Identifier of a debug session, stored as a `u32`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct SessionId(pub u32);

impl SessionId {
    /// Builds a `SessionId` from the `u64` used on the wire.
    ///
    /// Only the low 32 bits of `client_id` are kept; larger values wrap
    /// silently rather than producing an error.
    pub fn from_proto(client_id: u64) -> Self {
        // Truncating cast on purpose: this mirrors the existing wire behavior.
        let narrowed = client_id as u32;
        SessionId(narrowed)
    }

    /// Widens the id to the `u64` used on the wire. This conversion is lossless.
    pub fn to_proto(&self) -> u64 {
        u64::from(self.0)
    }
}
38
39/// Represents a connection to the debug adapter process, either via stdout/stdin or a socket.
40pub struct DebugAdapterClient {
41 id: SessionId,
42 name: DebugAdapterName,
43 sequence_count: AtomicU64,
44 binary: DebugAdapterBinary,
45 executor: BackgroundExecutor,
46 transport_delegate: TransportDelegate,
47}
48
49pub type DapMessageHandler = Box<dyn FnMut(Message) + 'static + Send + Sync>;
50
51impl DebugAdapterClient {
52 pub async fn start(
53 id: SessionId,
54 name: DebugAdapterName,
55 binary: DebugAdapterBinary,
56 message_handler: DapMessageHandler,
57 cx: AsyncApp,
58 ) -> Result<Self> {
59 let ((server_rx, server_tx), transport_delegate) =
60 TransportDelegate::start(&binary, cx.clone()).await?;
61 let this = Self {
62 id,
63 name,
64 binary,
65 transport_delegate,
66 sequence_count: AtomicU64::new(1),
67 executor: cx.background_executor().clone(),
68 };
69 log::info!("Successfully connected to debug adapter");
70
71 let client_id = this.id;
72
73 // start handling events/reverse requests
74 cx.background_spawn(Self::handle_receive_messages(
75 client_id,
76 server_rx,
77 server_tx.clone(),
78 message_handler,
79 ))
80 .detach();
81
82 Ok(this)
83 }
84
85 pub async fn reconnect(
86 &self,
87 session_id: SessionId,
88 binary: DebugAdapterBinary,
89 message_handler: DapMessageHandler,
90 cx: AsyncApp,
91 ) -> Result<Self> {
92 let binary = match self.transport_delegate.transport() {
93 crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
94 command: binary.command,
95 arguments: binary.arguments,
96 envs: binary.envs,
97 cwd: binary.cwd,
98 connection: Some(crate::adapters::TcpArguments {
99 host: tcp_transport.host,
100 port: tcp_transport.port,
101 timeout: Some(tcp_transport.timeout),
102 }),
103 },
104 _ => self.binary.clone(),
105 };
106
107 Self::start(session_id, self.name(), binary, message_handler, cx).await
108 }
109
110 async fn handle_receive_messages(
111 client_id: SessionId,
112 server_rx: Receiver<Message>,
113 client_tx: Sender<Message>,
114 mut message_handler: DapMessageHandler,
115 ) -> Result<()> {
116 let result = loop {
117 let message = match server_rx.recv().await {
118 Ok(message) => message,
119 Err(e) => break Err(e.into()),
120 };
121 match message {
122 Message::Event(ev) => {
123 log::debug!("Client {} received event `{}`", client_id.0, &ev);
124
125 message_handler(Message::Event(ev))
126 }
127 Message::Request(req) => {
128 log::debug!(
129 "Client {} received reverse request `{}`",
130 client_id.0,
131 &req.command
132 );
133
134 message_handler(Message::Request(req))
135 }
136 Message::Response(response) => {
137 log::debug!("Received response after request timeout: {:#?}", response);
138 }
139 }
140
141 smol::future::yield_now().await;
142 };
143
144 drop(client_tx);
145
146 log::debug!("Handle receive messages dropped");
147
148 result
149 }
150
151 /// Send a request to an adapter and get a response back
152 /// Note: This function will block until a response is sent back from the adapter
153 pub async fn request<R: Request>(&self, arguments: R::Arguments) -> Result<R::Response> {
154 let serialized_arguments = serde_json::to_value(arguments)?;
155
156 let (callback_tx, callback_rx) = oneshot::channel::<Result<Response>>();
157
158 let sequence_id = self.next_sequence_id();
159
160 let request = crate::messages::Request {
161 seq: sequence_id,
162 command: R::COMMAND.to_string(),
163 arguments: Some(serialized_arguments),
164 };
165 self.transport_delegate
166 .add_pending_request(sequence_id, callback_tx)
167 .await;
168
169 log::debug!(
170 "Client {} send `{}` request with sequence_id: {}",
171 self.id.0,
172 R::COMMAND,
173 sequence_id
174 );
175
176 self.send_message(Message::Request(request)).await?;
177
178 let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
179 let command = R::COMMAND.to_string();
180
181 select! {
182 response = callback_rx.fuse() => {
183 log::debug!(
184 "Client {} received response for: `{}` sequence_id: {}",
185 self.id.0,
186 command,
187 sequence_id
188 );
189
190 let response = response??;
191 match response.success {
192 true => Ok(serde_json::from_value(response.body.unwrap_or_default())?),
193 false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
194 }
195 }
196
197 _ = timeout => {
198 self.transport_delegate.cancel_pending_request(&sequence_id).await;
199 log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
200 anyhow::bail!("DAP request timeout");
201 }
202 }
203 }
204
205 pub async fn send_message(&self, message: Message) -> Result<()> {
206 self.transport_delegate.send_message(message).await
207 }
208
209 pub fn id(&self) -> SessionId {
210 self.id
211 }
212
213 pub fn name(&self) -> DebugAdapterName {
214 self.name.clone()
215 }
216 pub fn binary(&self) -> &DebugAdapterBinary {
217 &self.binary
218 }
219
220 /// Get the next sequence id to be used in a request
221 pub fn next_sequence_id(&self) -> u64 {
222 self.sequence_count.fetch_add(1, Ordering::Relaxed)
223 }
224
225 pub async fn shutdown(&self) -> Result<()> {
226 self.transport_delegate.shutdown().await
227 }
228
229 pub fn has_adapter_logs(&self) -> bool {
230 self.transport_delegate.has_adapter_logs()
231 }
232
233 pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
234 where
235 F: 'static + Send + FnMut(IoKind, &str),
236 {
237 self.transport_delegate.add_log_handler(f, kind);
238 }
239
/// Test support: installs `handler` on the fake transport for requests of type `R`.
///
/// The handler receives a `u64` and the request arguments, and returns
/// either the response or an error response.
#[cfg(any(test, feature = "test-support"))]
pub async fn on_request<R: dap_types::requests::Request, F>(&self, handler: F)
where
    F: FnMut(u64, R::Arguments) -> Result<R::Response, dap_types::ErrorResponse>
        + Send
        + 'static,
{
    self.transport_delegate
        .transport()
        .as_fake()
        .on_request::<R, F>(handler)
        .await;
}
250
/// Test support: sends a request of type `R` through the transport to simulate a reverse request.
///
/// The request's `seq` is the current value of the sequence counter, which
/// is read but not advanced. Arguments that fail to serialize are sent as
/// `None`.
///
/// # Panics
///
/// Panics if sending the message fails.
#[cfg(any(test, feature = "test-support"))]
pub async fn fake_reverse_request<R: dap_types::requests::Request>(&self, args: R::Arguments) {
    let reverse_request = dap_types::messages::Request {
        seq: self.sequence_count.load(Ordering::Relaxed),
        command: R::COMMAND.into(),
        arguments: serde_json::to_value(args).ok(),
    };

    self.send_message(Message::Request(reverse_request))
        .await
        .unwrap();
}
261
/// Test support: installs `handler` on the fake transport for responses to requests of type `R`.
#[cfg(any(test, feature = "test-support"))]
pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
where
    F: Fn(Response) + Send + 'static,
{
    self.transport_delegate
        .transport()
        .as_fake()
        .on_response::<R, F>(handler)
        .await;
}
270
/// Test support: sends `event` through the transport as a boxed `Message::Event`.
///
/// # Panics
///
/// Panics if sending the message fails.
#[cfg(any(test, feature = "test-support"))]
pub async fn fake_event(&self, event: dap_types::messages::Events) {
    let message = Message::Event(Box::new(event));
    self.send_message(message).await.unwrap();
}
277}
278
279#[cfg(test)]
280mod tests {
281 use super::*;
282 use crate::{client::DebugAdapterClient, debugger_settings::DebuggerSettings};
283 use dap_types::{
284 Capabilities, InitializeRequestArguments, InitializeRequestArgumentsPathFormat,
285 RunInTerminalRequestArguments,
286 messages::Events,
287 requests::{Initialize, Request, RunInTerminal},
288 };
289 use gpui::TestAppContext;
290 use serde_json::json;
291 use settings::{Settings, SettingsStore};
292 use std::sync::{
293 Arc,
294 atomic::{AtomicBool, Ordering},
295 };
296
297 pub fn init_test(cx: &mut gpui::TestAppContext) {
298 if std::env::var("RUST_LOG").is_ok() {
299 env_logger::try_init().ok();
300 }
301
302 cx.update(|cx| {
303 let settings = SettingsStore::test(cx);
304 cx.set_global(settings);
305 DebuggerSettings::register(cx);
306 });
307 }
308
309 #[gpui::test]
310 pub async fn test_initialize_client(cx: &mut TestAppContext) {
311 init_test(cx);
312
313 let client = DebugAdapterClient::start(
314 crate::client::SessionId(1),
315 DebugAdapterName("adapter".into()),
316 DebugAdapterBinary {
317 command: "command".into(),
318 arguments: Default::default(),
319 envs: Default::default(),
320 connection: None,
321 cwd: None,
322 },
323 Box::new(|_| panic!("Did not expect to hit this code path")),
324 cx.to_async(),
325 )
326 .await
327 .unwrap();
328
329 client
330 .on_request::<Initialize, _>(move |_, _| {
331 Ok(dap_types::Capabilities {
332 supports_configuration_done_request: Some(true),
333 ..Default::default()
334 })
335 })
336 .await;
337
338 cx.run_until_parked();
339
340 let response = client
341 .request::<Initialize>(InitializeRequestArguments {
342 client_id: Some("zed".to_owned()),
343 client_name: Some("Zed".to_owned()),
344 adapter_id: "fake-adapter".to_owned(),
345 locale: Some("en-US".to_owned()),
346 path_format: Some(InitializeRequestArgumentsPathFormat::Path),
347 supports_variable_type: Some(true),
348 supports_variable_paging: Some(false),
349 supports_run_in_terminal_request: Some(true),
350 supports_memory_references: Some(true),
351 supports_progress_reporting: Some(false),
352 supports_invalidated_event: Some(false),
353 lines_start_at1: Some(true),
354 columns_start_at1: Some(true),
355 supports_memory_event: Some(false),
356 supports_args_can_be_interpreted_by_shell: Some(false),
357 supports_start_debugging_request: Some(true),
358 supports_ansistyling: Some(false),
359 })
360 .await
361 .unwrap();
362
363 cx.run_until_parked();
364
365 assert_eq!(
366 dap_types::Capabilities {
367 supports_configuration_done_request: Some(true),
368 ..Default::default()
369 },
370 response
371 );
372
373 client.shutdown().await.unwrap();
374 }
375
376 #[gpui::test]
377 pub async fn test_calls_event_handler(cx: &mut TestAppContext) {
378 init_test(cx);
379
380 let called_event_handler = Arc::new(AtomicBool::new(false));
381
382 let client = DebugAdapterClient::start(
383 crate::client::SessionId(1),
384 DebugAdapterName("adapter".into()),
385 DebugAdapterBinary {
386 command: "command".into(),
387 arguments: Default::default(),
388 envs: Default::default(),
389 connection: None,
390 cwd: None,
391 },
392 Box::new({
393 let called_event_handler = called_event_handler.clone();
394 move |event| {
395 called_event_handler.store(true, Ordering::SeqCst);
396
397 assert_eq!(
398 Message::Event(Box::new(Events::Initialized(
399 Some(Capabilities::default())
400 ))),
401 event
402 );
403 }
404 }),
405 cx.to_async(),
406 )
407 .await
408 .unwrap();
409
410 cx.run_until_parked();
411
412 client
413 .fake_event(Events::Initialized(Some(Capabilities::default())))
414 .await;
415
416 cx.run_until_parked();
417
418 assert!(
419 called_event_handler.load(std::sync::atomic::Ordering::SeqCst),
420 "Event handler was not called"
421 );
422
423 client.shutdown().await.unwrap();
424 }
425
426 #[gpui::test]
427 pub async fn test_calls_event_handler_for_reverse_request(cx: &mut TestAppContext) {
428 init_test(cx);
429
430 let called_event_handler = Arc::new(AtomicBool::new(false));
431
432 let client = DebugAdapterClient::start(
433 crate::client::SessionId(1),
434 DebugAdapterName("test-adapter".into()),
435 DebugAdapterBinary {
436 command: "command".into(),
437 arguments: Default::default(),
438 envs: Default::default(),
439 connection: None,
440 cwd: None,
441 },
442 Box::new({
443 let called_event_handler = called_event_handler.clone();
444 move |event| {
445 called_event_handler.store(true, Ordering::SeqCst);
446
447 assert_eq!(
448 Message::Request(dap_types::messages::Request {
449 seq: 1,
450 command: RunInTerminal::COMMAND.into(),
451 arguments: Some(json!({
452 "cwd": "/project/path/src",
453 "args": ["node", "test.js"],
454 }))
455 }),
456 event
457 );
458 }
459 }),
460 cx.to_async(),
461 )
462 .await
463 .unwrap();
464
465 cx.run_until_parked();
466
467 client
468 .fake_reverse_request::<RunInTerminal>(RunInTerminalRequestArguments {
469 kind: None,
470 title: None,
471 cwd: "/project/path/src".into(),
472 args: vec!["node".into(), "test.js".into()],
473 env: None,
474 args_can_be_interpreted_by_shell: None,
475 })
476 .await;
477
478 cx.run_until_parked();
479
480 assert!(
481 called_event_handler.load(std::sync::atomic::Ordering::SeqCst),
482 "Event handler was not called"
483 );
484
485 client.shutdown().await.unwrap();
486 }
487}