transport.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2#[cfg(any(test, feature = "test-support"))]
   3use async_pipe::{PipeReader, PipeWriter};
   4use dap_types::{
   5    ErrorResponse,
   6    messages::{Message, Response},
   7};
   8use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
   9use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
  10use parking_lot::Mutex;
  11use proto::ErrorExt;
  12use settings::Settings as _;
  13use smallvec::SmallVec;
  14use smol::{
  15    channel::{Receiver, Sender, unbounded},
  16    io::{AsyncBufReadExt as _, AsyncWriteExt, BufReader},
  17    net::{TcpListener, TcpStream},
  18};
  19use std::{
  20    collections::HashMap,
  21    net::{Ipv4Addr, SocketAddrV4},
  22    process::Stdio,
  23    sync::Arc,
  24    time::Duration,
  25};
  26use task::TcpArgumentsTemplate;
  27use util::ConnectionResult;
  28
  29use crate::{
  30    adapters::{DebugAdapterBinary, TcpArguments},
  31    client::DapMessageHandler,
  32    debugger_settings::DebuggerSettings,
  33};
  34
/// Text of a logged message: a raw DAP message or a line of adapter output.
pub(crate) type IoMessage = str;
/// Name of the DAP command a logged message belongs to, when there is one.
pub(crate) type Command = str;
/// Log callback, invoked with the stream the message travelled on, its command (if any),
/// and the message text.
pub type IoHandler = Box<dyn Send + FnMut(IoKind, Option<&Command>, &IoMessage)>;
  38
  39#[derive(PartialEq, Eq, Clone, Copy)]
  40pub enum LogKind {
  41    Adapter,
  42    Rpc,
  43}
  44
  45#[derive(Clone, Copy)]
  46pub enum IoKind {
  47    StdIn,
  48    StdOut,
  49    StdErr,
  50}
  51
/// Outcome of a fake adapter's request handler in tests: reply with a value, or simulate
/// the adapter exiting instead of answering.
#[cfg(any(test, feature = "test-support"))]
pub enum RequestHandling<T> {
    /// Send `T` back to the client as the response.
    Respond(T),
    /// Do not reply; the fake adapter's message loop stops with an error instead.
    Exit,
}
  57
/// Shared, lock-protected list of registered log handlers, each paired with the
/// [`LogKind`] it listens to (stored inline for up to two handlers).
type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
  59
/// A mechanism for talking to a debug adapter: stdio pipes, a TCP socket, or a test fake.
pub trait Transport: Send + Sync {
    /// Reports whether adapter logs are available from this transport (both the TCP and
    /// stdio transports in this file return `true`).
    fn has_adapter_logs(&self) -> bool;
    /// The TCP host/port/timeout this transport uses, or `None` if it does not use TCP.
    fn tcp_arguments(&self) -> Option<TcpArguments>;
    /// Establishes the connection, resolving to `(writer, reader)`: the writer carries
    /// messages to the adapter and the reader carries messages from it.
    fn connect(
        &mut self,
    ) -> Task<
        Result<(
            Box<dyn AsyncWrite + Unpin + Send + 'static>,
            Box<dyn AsyncRead + Unpin + Send + 'static>,
        )>,
    >;
    /// Terminates the adapter process owned by this transport, if any.
    fn kill(&mut self);
    /// Test-only downcast helper; only the fake transport is expected to override it, so
    /// reaching this default implementation is a bug.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &FakeTransport {
        unreachable!()
    }
}
  77
  78async fn start(
  79    binary: &DebugAdapterBinary,
  80    log_handlers: LogHandlers,
  81    cx: &mut AsyncApp,
  82) -> Result<Box<dyn Transport>> {
  83    #[cfg(any(test, feature = "test-support"))]
  84    if cfg!(any(test, feature = "test-support")) {
  85        if let Some(connection) = binary.connection.clone() {
  86            return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
  87        } else {
  88            return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
  89        }
  90    }
  91
  92    if binary.connection.is_some() {
  93        Ok(Box::new(
  94            TcpTransport::start(binary, log_handlers, cx).await?,
  95        ))
  96    } else {
  97        Ok(Box::new(
  98            StdioTransport::start(binary, log_handlers, cx).await?,
  99        ))
 100    }
 101}
 102
 103pub(crate) struct PendingRequests {
 104    inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
 105}
 106
 107impl PendingRequests {
 108    fn new() -> Self {
 109        Self {
 110            inner: Some(HashMap::default()),
 111        }
 112    }
 113
 114    fn flush(&mut self, e: anyhow::Error) {
 115        let Some(inner) = self.inner.as_mut() else {
 116            return;
 117        };
 118        for (_, sender) in inner.drain() {
 119            sender.send(Err(e.cloned())).ok();
 120        }
 121    }
 122
 123    pub(crate) fn insert(
 124        &mut self,
 125        sequence_id: u64,
 126        callback_tx: oneshot::Sender<Result<Response>>,
 127    ) -> anyhow::Result<()> {
 128        let Some(inner) = self.inner.as_mut() else {
 129            bail!("client is closed")
 130        };
 131        inner.insert(sequence_id, callback_tx);
 132        Ok(())
 133    }
 134
 135    pub(crate) fn remove(
 136        &mut self,
 137        sequence_id: u64,
 138    ) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
 139        let Some(inner) = self.inner.as_mut() else {
 140            bail!("client is closed");
 141        };
 142        Ok(inner.remove(&sequence_id))
 143    }
 144
 145    pub(crate) fn shutdown(&mut self) {
 146        self.flush(anyhow!("transport shutdown"));
 147        self.inner = None;
 148    }
 149}
 150
 151pub(crate) struct TransportDelegate {
 152    log_handlers: LogHandlers,
 153    pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
 154    pub(crate) transport: Mutex<Box<dyn Transport>>,
 155    pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
 156    tasks: Mutex<Vec<Task<()>>>,
 157}
 158
 159impl TransportDelegate {
 160    pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
 161        let log_handlers: LogHandlers = Default::default();
 162        let transport = start(binary, log_handlers.clone(), cx).await?;
 163        Ok(Self {
 164            transport: Mutex::new(transport),
 165            log_handlers,
 166            server_tx: Default::default(),
 167            pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
 168            tasks: Default::default(),
 169        })
 170    }
 171
 172    pub async fn connect(
 173        &self,
 174        message_handler: DapMessageHandler,
 175        cx: &mut AsyncApp,
 176    ) -> Result<()> {
 177        let (server_tx, client_rx) = unbounded::<Message>();
 178        self.tasks.lock().clear();
 179
 180        let log_dap_communications =
 181            cx.update(|cx| DebuggerSettings::get_global(cx).log_dap_communications)
 182                .with_context(|| "Failed to get Debugger Setting log dap communications error in transport::start_handlers. Defaulting to false")
 183                .unwrap_or(false);
 184
 185        let connect = self.transport.lock().connect();
 186        let (input, output) = connect.await?;
 187
 188        let log_handler = if log_dap_communications {
 189            Some(self.log_handlers.clone())
 190        } else {
 191            None
 192        };
 193
 194        let pending_requests = self.pending_requests.clone();
 195        let output_log_handler = log_handler.clone();
 196        {
 197            let mut tasks = self.tasks.lock();
 198            tasks.push(cx.background_spawn(async move {
 199                match Self::recv_from_server(
 200                    output,
 201                    message_handler,
 202                    pending_requests.clone(),
 203                    output_log_handler,
 204                )
 205                .await
 206                {
 207                    Ok(()) => {
 208                        pending_requests
 209                            .lock()
 210                            .flush(anyhow!("debugger shutdown unexpectedly"));
 211                    }
 212                    Err(e) => {
 213                        pending_requests.lock().flush(e);
 214                    }
 215                }
 216            }));
 217
 218            tasks.push(cx.background_spawn(async move {
 219                match Self::send_to_server(input, client_rx, log_handler).await {
 220                    Ok(()) => {}
 221                    Err(e) => log::error!("Error handling debugger input: {e}"),
 222                }
 223            }));
 224        }
 225
 226        {
 227            let mut lock = self.server_tx.lock().await;
 228            *lock = Some(server_tx.clone());
 229        }
 230
 231        Ok(())
 232    }
 233
 234    pub(crate) fn tcp_arguments(&self) -> Option<TcpArguments> {
 235        self.transport.lock().tcp_arguments()
 236    }
 237
 238    pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
 239        if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
 240            server_tx.send(message).await.context("sending message")
 241        } else {
 242            anyhow::bail!("Server tx already dropped")
 243        }
 244    }
 245
 246    async fn handle_adapter_log(
 247        stdout: impl AsyncRead + Unpin + Send + 'static,
 248        iokind: IoKind,
 249        log_handlers: LogHandlers,
 250    ) {
 251        let mut reader = BufReader::new(stdout);
 252        let mut line = String::new();
 253
 254        loop {
 255            line.truncate(0);
 256
 257            match reader.read_line(&mut line).await {
 258                Ok(0) => break,
 259                Ok(_) => {}
 260                Err(e) => {
 261                    log::debug!("handle_adapter_log: {}", e);
 262                    break;
 263                }
 264            }
 265
 266            // Clean up logs by trimming unnecessary whitespace/newlines before inserting into log.
 267            let line = line.trim();
 268
 269            log::debug!("stderr: {line}");
 270
 271            for (kind, handler) in log_handlers.lock().iter_mut() {
 272                if matches!(kind, LogKind::Adapter) {
 273                    handler(iokind, None, line);
 274                }
 275            }
 276        }
 277    }
 278
 279    fn build_rpc_message(message: String) -> String {
 280        format!("Content-Length: {}\r\n\r\n{}", message.len(), message)
 281    }
 282
 283    async fn send_to_server<Stdin>(
 284        mut server_stdin: Stdin,
 285        client_rx: Receiver<Message>,
 286        log_handlers: Option<LogHandlers>,
 287    ) -> Result<()>
 288    where
 289        Stdin: AsyncWrite + Unpin + Send + 'static,
 290    {
 291        let result = loop {
 292            match client_rx.recv().await {
 293                Ok(message) => {
 294                    let command = match &message {
 295                        Message::Request(request) => Some(request.command.as_str()),
 296                        Message::Response(response) => Some(response.command.as_str()),
 297                        _ => None,
 298                    };
 299
 300                    let message = match serde_json::to_string(&message) {
 301                        Ok(message) => message,
 302                        Err(e) => break Err(e.into()),
 303                    };
 304
 305                    if let Some(log_handlers) = log_handlers.as_ref() {
 306                        for (kind, log_handler) in log_handlers.lock().iter_mut() {
 307                            if matches!(kind, LogKind::Rpc) {
 308                                log_handler(IoKind::StdIn, command, &message);
 309                            }
 310                        }
 311                    }
 312
 313                    if let Err(e) = server_stdin
 314                        .write_all(Self::build_rpc_message(message).as_bytes())
 315                        .await
 316                    {
 317                        break Err(e.into());
 318                    }
 319
 320                    if let Err(e) = server_stdin.flush().await {
 321                        break Err(e.into());
 322                    }
 323                }
 324                Err(error) => break Err(error.into()),
 325            }
 326        };
 327
 328        log::debug!("Handle adapter input dropped");
 329
 330        result
 331    }
 332
 333    async fn recv_from_server<Stdout>(
 334        server_stdout: Stdout,
 335        mut message_handler: DapMessageHandler,
 336        pending_requests: Arc<Mutex<PendingRequests>>,
 337        log_handlers: Option<LogHandlers>,
 338    ) -> Result<()>
 339    where
 340        Stdout: AsyncRead + Unpin + Send + 'static,
 341    {
 342        let mut recv_buffer = String::new();
 343        let mut reader = BufReader::new(server_stdout);
 344
 345        let result = loop {
 346            let result =
 347                Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
 348                    .await;
 349            match result {
 350                ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
 351                ConnectionResult::ConnectionReset => {
 352                    log::info!("Debugger closed the connection");
 353                    return Ok(());
 354                }
 355                ConnectionResult::Result(Ok(Message::Response(res))) => {
 356                    let tx = pending_requests.lock().remove(res.request_seq)?;
 357                    if let Some(tx) = tx {
 358                        if let Err(e) = tx.send(Self::process_response(res)) {
 359                            log::trace!("Did not send response `{:?}` for a cancelled", e);
 360                        }
 361                    } else {
 362                        message_handler(Message::Response(res))
 363                    }
 364                }
 365                ConnectionResult::Result(Ok(message)) => message_handler(message),
 366                ConnectionResult::Result(Err(e)) => break Err(e),
 367            }
 368        };
 369
 370        log::debug!("Handle adapter output dropped");
 371
 372        result
 373    }
 374
 375    fn process_response(response: Response) -> Result<Response> {
 376        if response.success {
 377            Ok(response)
 378        } else {
 379            if let Some(error_message) = response
 380                .body
 381                .clone()
 382                .and_then(|body| serde_json::from_value::<ErrorResponse>(body).ok())
 383                .and_then(|response| response.error.map(|msg| msg.format))
 384                .or_else(|| response.message.clone())
 385            {
 386                anyhow::bail!(error_message);
 387            };
 388
 389            anyhow::bail!(
 390                "Received error response from adapter. Response: {:?}",
 391                response
 392            );
 393        }
 394    }
 395
 396    async fn receive_server_message<Stdout>(
 397        reader: &mut BufReader<Stdout>,
 398        buffer: &mut String,
 399        log_handlers: Option<&LogHandlers>,
 400    ) -> ConnectionResult<Message>
 401    where
 402        Stdout: AsyncRead + Unpin + Send + 'static,
 403    {
 404        let mut content_length = None;
 405        loop {
 406            buffer.truncate(0);
 407            match reader.read_line(buffer).await {
 408                Ok(0) => return ConnectionResult::ConnectionReset,
 409                Ok(_) => {}
 410                Err(e) => return ConnectionResult::Result(Err(e.into())),
 411            };
 412
 413            if buffer == "\r\n" {
 414                break;
 415            }
 416
 417            if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
 418                match value.parse().context("invalid content length") {
 419                    Ok(length) => content_length = Some(length),
 420                    Err(e) => return ConnectionResult::Result(Err(e)),
 421                }
 422            }
 423        }
 424
 425        let content_length = match content_length.context("missing content length") {
 426            Ok(length) => length,
 427            Err(e) => return ConnectionResult::Result(Err(e)),
 428        };
 429
 430        let mut content = vec![0; content_length];
 431        if let Err(e) = reader
 432            .read_exact(&mut content)
 433            .await
 434            .with_context(|| "reading after a loop")
 435        {
 436            return ConnectionResult::Result(Err(e));
 437        }
 438
 439        let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
 440            Ok(str) => str,
 441            Err(e) => return ConnectionResult::Result(Err(e)),
 442        };
 443
 444        let message =
 445            serde_json::from_str::<Message>(message_str).context("deserializing server message");
 446
 447        if let Some(log_handlers) = log_handlers {
 448            let command = match &message {
 449                Ok(Message::Request(request)) => Some(request.command.as_str()),
 450                Ok(Message::Response(response)) => Some(response.command.as_str()),
 451                _ => None,
 452            };
 453
 454            for (kind, log_handler) in log_handlers.lock().iter_mut() {
 455                if matches!(kind, LogKind::Rpc) {
 456                    log_handler(IoKind::StdOut, command, message_str);
 457                }
 458            }
 459        }
 460
 461        ConnectionResult::Result(message)
 462    }
 463
 464    pub fn has_adapter_logs(&self) -> bool {
 465        self.transport.lock().has_adapter_logs()
 466    }
 467
 468    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
 469    where
 470        F: 'static + Send + FnMut(IoKind, Option<&Command>, &IoMessage),
 471    {
 472        let mut log_handlers = self.log_handlers.lock();
 473        log_handlers.push((kind, Box::new(f)));
 474    }
 475}
 476
/// Transport that reaches the debug adapter over a TCP socket. It optionally spawns the
/// adapter process itself first.
pub struct TcpTransport {
    /// Executor used for the connection attempt, its retry delay and its timeout.
    executor: BackgroundExecutor,
    /// Port the adapter is expected to listen on.
    pub port: u16,
    /// Host the adapter is expected to listen on.
    pub host: Ipv4Addr,
    /// Connection timeout, in milliseconds.
    pub timeout: u64,
    /// The adapter process spawned by this transport, if it spawned one.
    process: Arc<Mutex<Option<Child>>>,
    // Never read; held presumably so the adapter-output log tasks are not dropped.
    _stderr_task: Option<Task<()>>,
    _stdout_task: Option<Task<()>>,
}
 486
 487impl TcpTransport {
 488    /// Get an open port to use with the tcp client when not supplied by debug config
 489    pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
 490        if let Some(port) = host.port {
 491            Ok(port)
 492        } else {
 493            Self::unused_port(host.host()).await
 494        }
 495    }
 496
 497    pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
 498        Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
 499            .await?
 500            .local_addr()?
 501            .port())
 502    }
 503
 504    async fn start(
 505        binary: &DebugAdapterBinary,
 506        log_handlers: LogHandlers,
 507        cx: &mut AsyncApp,
 508    ) -> Result<Self> {
 509        let connection_args = binary
 510            .connection
 511            .as_ref()
 512            .context("No connection arguments provided")?;
 513
 514        let host = connection_args.host;
 515        let port = connection_args.port;
 516
 517        let mut process = None;
 518        let mut stdout_task = None;
 519        let mut stderr_task = None;
 520
 521        if let Some(command) = &binary.command {
 522            let mut command = util::command::new_std_command(&command);
 523
 524            if let Some(cwd) = &binary.cwd {
 525                command.current_dir(cwd);
 526            }
 527
 528            command.args(&binary.arguments);
 529            command.envs(&binary.envs);
 530
 531            let mut p = Child::spawn(command, Stdio::null())
 532                .with_context(|| "failed to start debug adapter.")?;
 533
 534            stdout_task = p.stdout.take().map(|stdout| {
 535                cx.background_executor()
 536                    .spawn(TransportDelegate::handle_adapter_log(
 537                        stdout,
 538                        IoKind::StdOut,
 539                        log_handlers.clone(),
 540                    ))
 541            });
 542            stderr_task = p.stderr.take().map(|stderr| {
 543                cx.background_executor()
 544                    .spawn(TransportDelegate::handle_adapter_log(
 545                        stderr,
 546                        IoKind::StdErr,
 547                        log_handlers,
 548                    ))
 549            });
 550            process = Some(p);
 551        };
 552
 553        let timeout = connection_args.timeout.unwrap_or_else(|| {
 554            cx.update(|cx| DebuggerSettings::get_global(cx).timeout)
 555                .unwrap_or(20000u64)
 556        });
 557
 558        log::info!(
 559            "Debug adapter has connected to TCP server {}:{}",
 560            host,
 561            port
 562        );
 563
 564        let this = Self {
 565            executor: cx.background_executor().clone(),
 566            port,
 567            host,
 568            process: Arc::new(Mutex::new(process)),
 569            timeout,
 570            _stdout_task: stdout_task,
 571            _stderr_task: stderr_task,
 572        };
 573
 574        Ok(this)
 575    }
 576}
 577
 578impl Transport for TcpTransport {
 579    fn has_adapter_logs(&self) -> bool {
 580        true
 581    }
 582
 583    fn kill(&mut self) {
 584        if let Some(process) = &mut *self.process.lock() {
 585            process.kill();
 586        }
 587    }
 588
 589    fn tcp_arguments(&self) -> Option<TcpArguments> {
 590        Some(TcpArguments {
 591            host: self.host,
 592            port: self.port,
 593            timeout: Some(self.timeout),
 594        })
 595    }
 596
 597    fn connect(
 598        &mut self,
 599    ) -> Task<
 600        Result<(
 601            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 602            Box<dyn AsyncRead + Unpin + Send + 'static>,
 603        )>,
 604    > {
 605        let executor = self.executor.clone();
 606        let timeout = self.timeout;
 607        let address = SocketAddrV4::new(self.host, self.port);
 608        let process = self.process.clone();
 609        executor.clone().spawn(async move {
 610            select! {
 611                _ = executor.timer(Duration::from_millis(timeout)).fuse() => {
 612                    anyhow::bail!("Connection to TCP DAP timeout {address}");
 613                },
 614                result = executor.clone().spawn(async move {
 615                    loop {
 616                        match TcpStream::connect(address).await {
 617                            Ok(stream) => {
 618                                let (read, write) = stream.split();
 619                                return Ok((Box::new(write) as _, Box::new(read) as _))
 620                            },
 621                            Err(_) => {
 622                                let has_process = process.lock().is_some();
 623                                if has_process {
 624                                    let status = process.lock().as_mut().unwrap().try_status();
 625                                    if let Ok(Some(_)) = status {
 626                                        let process = process.lock().take().unwrap().into_inner();
 627                                        let output = process.output().await?;
 628                                        let output = if output.stderr.is_empty() {
 629                                            String::from_utf8_lossy(&output.stdout).to_string()
 630                                        } else {
 631                                            String::from_utf8_lossy(&output.stderr).to_string()
 632                                        };
 633                                        anyhow::bail!("{output}\nerror: process exited before debugger attached.");
 634                                    }
 635                                }
 636
 637                                executor.timer(Duration::from_millis(100)).await;
 638                            }
 639                        }
 640                    }
 641                }).fuse() => result
 642            }
 643        })
 644    }
 645}
 646
 647impl Drop for TcpTransport {
 648    fn drop(&mut self) {
 649        if let Some(mut p) = self.process.lock().take() {
 650            p.kill()
 651        }
 652    }
 653}
 654
/// Transport that talks to a spawned debug adapter over the process's stdin/stdout pipes.
pub struct StdioTransport {
    /// The adapter process; its stdin/stdout are taken by `connect`.
    process: Mutex<Child>,
    // Never read; held presumably so the stderr log task is not dropped.
    _stderr_task: Option<Task<()>>,
}
 659
 660impl StdioTransport {
 661    // #[allow(dead_code, reason = "This is used in non test builds of Zed")]
 662    async fn start(
 663        binary: &DebugAdapterBinary,
 664        log_handlers: LogHandlers,
 665        cx: &mut AsyncApp,
 666    ) -> Result<Self> {
 667        let Some(binary_command) = &binary.command else {
 668            bail!(
 669                "When using the `stdio` transport, the path to a debug adapter binary must be set by Zed."
 670            );
 671        };
 672        let mut command = util::command::new_std_command(&binary_command);
 673
 674        if let Some(cwd) = &binary.cwd {
 675            command.current_dir(cwd);
 676        }
 677
 678        command.args(&binary.arguments);
 679        command.envs(&binary.envs);
 680
 681        let mut process = Child::spawn(command, Stdio::piped())?;
 682
 683        let _stderr_task = process.stderr.take().map(|stderr| {
 684            cx.background_spawn(TransportDelegate::handle_adapter_log(
 685                stderr,
 686                IoKind::StdErr,
 687                log_handlers,
 688            ))
 689        });
 690
 691        let process = Mutex::new(process);
 692
 693        Ok(Self {
 694            process,
 695            _stderr_task,
 696        })
 697    }
 698}
 699
 700impl Transport for StdioTransport {
 701    fn has_adapter_logs(&self) -> bool {
 702        true
 703    }
 704
 705    fn kill(&mut self) {
 706        self.process.lock().kill();
 707    }
 708
 709    fn connect(
 710        &mut self,
 711    ) -> Task<
 712        Result<(
 713            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 714            Box<dyn AsyncRead + Unpin + Send + 'static>,
 715        )>,
 716    > {
 717        let result = util::maybe!({
 718            let mut process = self.process.lock();
 719            Ok((
 720                Box::new(process.stdin.take().context("Cannot reconnect")?) as _,
 721                Box::new(process.stdout.take().context("Cannot reconnect")?) as _,
 722            ))
 723        });
 724        Task::ready(result)
 725    }
 726
 727    fn tcp_arguments(&self) -> Option<TcpArguments> {
 728        None
 729    }
 730}
 731
 732impl Drop for StdioTransport {
 733    fn drop(&mut self) {
 734        self.process.lock().kill();
 735    }
 736}
 737
/// Fake-adapter handler for one request type: receives the request's `seq` and raw JSON
/// arguments and decides how to reply.
#[cfg(any(test, feature = "test-support"))]
type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;

/// Fake-adapter callback receiving the client's response to a reverse request.
#[cfg(any(test, feature = "test-support"))]
type ResponseHandler = Box<dyn Send + Fn(Response)>;
 743
/// In-process stand-in for a debug adapter, used in tests.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeTransport {
    // for sending fake response back from adapter side
    request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
    // for reverse request responses
    response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
    // Background task for the fake adapter side, if one has been started (`start_tcp`
    // leaves it `None`); presumably runs the message loop — TODO confirm.
    message_handler: Option<Task<Result<()>>>,
    // Which real transport this fake imitates.
    kind: FakeTransportKind,
}
 753
/// Which real transport a [`FakeTransport`] imitates.
#[cfg(any(test, feature = "test-support"))]
pub enum FakeTransportKind {
    /// Stdio-style fake backed by in-memory pipes. The pipe ends are `Option`s,
    /// presumably so they can be taken when connecting.
    Stdio {
        stdin_writer: Option<PipeWriter>,
        stdout_reader: Option<PipeReader>,
    },
    /// TCP-style fake carrying the TCP connection arguments and a background executor.
    Tcp {
        connection: TcpArguments,
        executor: BackgroundExecutor,
    },
}
 765
 766#[cfg(any(test, feature = "test-support"))]
 767impl FakeTransport {
 768    pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
 769    where
 770        F: 'static
 771            + Send
 772            + FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
 773    {
 774        self.request_handlers.lock().insert(
 775            R::COMMAND,
 776            Box::new(move |seq, args| {
 777                let result = handler(seq, serde_json::from_value(args).unwrap());
 778                let RequestHandling::Respond(response) = result else {
 779                    return RequestHandling::Exit;
 780                };
 781                let response = match response {
 782                    Ok(response) => Response {
 783                        seq: seq + 1,
 784                        request_seq: seq,
 785                        success: true,
 786                        command: R::COMMAND.into(),
 787                        body: Some(serde_json::to_value(response).unwrap()),
 788                        message: None,
 789                    },
 790                    Err(response) => Response {
 791                        seq: seq + 1,
 792                        request_seq: seq,
 793                        success: false,
 794                        command: R::COMMAND.into(),
 795                        body: Some(serde_json::to_value(response).unwrap()),
 796                        message: None,
 797                    },
 798                };
 799                RequestHandling::Respond(response)
 800            }),
 801        );
 802    }
 803
 804    pub fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
 805    where
 806        F: 'static + Send + Fn(Response),
 807    {
 808        self.response_handlers
 809            .lock()
 810            .insert(R::COMMAND, Box::new(handler));
 811    }
 812
 813    async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
 814        Ok(Self {
 815            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 816            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 817            message_handler: None,
 818            kind: FakeTransportKind::Tcp {
 819                connection,
 820                executor: cx.background_executor().clone(),
 821            },
 822        })
 823    }
 824
 825    async fn handle_messages(
 826        request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 827        response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 828        stdin_reader: PipeReader,
 829        stdout_writer: PipeWriter,
 830    ) -> Result<()> {
 831        use dap_types::requests::{Request, RunInTerminal, StartDebugging};
 832        use serde_json::json;
 833
 834        let mut reader = BufReader::new(stdin_reader);
 835        let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
 836        let mut buffer = String::new();
 837
 838        loop {
 839            match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
 840                ConnectionResult::Timeout => {
 841                    anyhow::bail!("Timed out when connecting to debugger");
 842                }
 843                ConnectionResult::ConnectionReset => {
 844                    log::info!("Debugger closed the connection");
 845                    break Ok(());
 846                }
 847                ConnectionResult::Result(Err(e)) => break Err(e),
 848                ConnectionResult::Result(Ok(message)) => {
 849                    match message {
 850                        Message::Request(request) => {
 851                            // redirect reverse requests to stdout writer/reader
 852                            if request.command == RunInTerminal::COMMAND
 853                                || request.command == StartDebugging::COMMAND
 854                            {
 855                                let message =
 856                                    serde_json::to_string(&Message::Request(request)).unwrap();
 857
 858                                let mut writer = stdout_writer.lock().await;
 859                                writer
 860                                    .write_all(
 861                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 862                                    )
 863                                    .await
 864                                    .unwrap();
 865                                writer.flush().await.unwrap();
 866                            } else {
 867                                let response = if let Some(handle) =
 868                                    request_handlers.lock().get_mut(request.command.as_str())
 869                                {
 870                                    handle(request.seq, request.arguments.unwrap_or(json!({})))
 871                                } else {
 872                                    panic!("No request handler for {}", request.command);
 873                                };
 874                                let response = match response {
 875                                    RequestHandling::Respond(response) => response,
 876                                    RequestHandling::Exit => {
 877                                        break Err(anyhow!("exit in response to request"));
 878                                    }
 879                                };
 880                                let success = response.success;
 881                                let message =
 882                                    serde_json::to_string(&Message::Response(response)).unwrap();
 883
 884                                let mut writer = stdout_writer.lock().await;
 885                                writer
 886                                    .write_all(
 887                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 888                                    )
 889                                    .await
 890                                    .unwrap();
 891
 892                                if request.command == dap_types::requests::Initialize::COMMAND
 893                                    && success
 894                                {
 895                                    let message = serde_json::to_string(&Message::Event(Box::new(
 896                                        dap_types::messages::Events::Initialized(Some(
 897                                            Default::default(),
 898                                        )),
 899                                    )))
 900                                    .unwrap();
 901                                    writer
 902                                        .write_all(
 903                                            TransportDelegate::build_rpc_message(message)
 904                                                .as_bytes(),
 905                                        )
 906                                        .await
 907                                        .unwrap();
 908                                }
 909
 910                                writer.flush().await.unwrap();
 911                            }
 912                        }
 913                        Message::Event(event) => {
 914                            let message = serde_json::to_string(&Message::Event(event)).unwrap();
 915
 916                            let mut writer = stdout_writer.lock().await;
 917                            writer
 918                                .write_all(TransportDelegate::build_rpc_message(message).as_bytes())
 919                                .await
 920                                .unwrap();
 921                            writer.flush().await.unwrap();
 922                        }
 923                        Message::Response(response) => {
 924                            if let Some(handle) =
 925                                response_handlers.lock().get(response.command.as_str())
 926                            {
 927                                handle(response);
 928                            } else {
 929                                log::error!("No response handler for {}", response.command);
 930                            }
 931                        }
 932                    }
 933                }
 934            }
 935        }
 936    }
 937
 938    async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
 939        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 940        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 941        let kind = FakeTransportKind::Stdio {
 942            stdin_writer: Some(stdin_writer),
 943            stdout_reader: Some(stdout_reader),
 944        };
 945
 946        let mut this = Self {
 947            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 948            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 949            message_handler: None,
 950            kind,
 951        };
 952
 953        let request_handlers = this.request_handlers.clone();
 954        let response_handlers = this.response_handlers.clone();
 955
 956        this.message_handler = Some(cx.background_spawn(Self::handle_messages(
 957            request_handlers,
 958            response_handlers,
 959            stdin_reader,
 960            stdout_writer,
 961        )));
 962
 963        Ok(this)
 964    }
 965}
 966
 967#[cfg(any(test, feature = "test-support"))]
 968impl Transport for FakeTransport {
 969    fn tcp_arguments(&self) -> Option<TcpArguments> {
 970        match &self.kind {
 971            FakeTransportKind::Stdio { .. } => None,
 972            FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
 973        }
 974    }
 975
 976    fn connect(
 977        &mut self,
 978    ) -> Task<
 979        Result<(
 980            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 981            Box<dyn AsyncRead + Unpin + Send + 'static>,
 982        )>,
 983    > {
 984        let result = match &mut self.kind {
 985            FakeTransportKind::Stdio {
 986                stdin_writer,
 987                stdout_reader,
 988            } => util::maybe!({
 989                Ok((
 990                    Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
 991                    Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
 992                ))
 993            }),
 994            FakeTransportKind::Tcp { executor, .. } => {
 995                let (stdin_writer, stdin_reader) = async_pipe::pipe();
 996                let (stdout_writer, stdout_reader) = async_pipe::pipe();
 997
 998                let request_handlers = self.request_handlers.clone();
 999                let response_handlers = self.response_handlers.clone();
1000
1001                self.message_handler = Some(executor.spawn(Self::handle_messages(
1002                    request_handlers,
1003                    response_handlers,
1004                    stdin_reader,
1005                    stdout_writer,
1006                )));
1007
1008                Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
1009            }
1010        };
1011        Task::ready(result)
1012    }
1013
1014    fn has_adapter_logs(&self) -> bool {
1015        false
1016    }
1017
1018    fn kill(&mut self) {
1019        self.message_handler.take();
1020    }
1021
1022    #[cfg(any(test, feature = "test-support"))]
1023    fn as_fake(&self) -> &FakeTransport {
1024        self
1025    }
1026}
1027
/// Wrapper around a spawned `smol::process::Child`.
///
/// It exists so this module can attach its own platform-specific spawn and kill logic
/// (see the inherent `impl Child`). It derefs to the wrapped process for everything else.
struct Child {
    /// The underlying async child process.
    process: smol::process::Child,
}
1031
1032impl std::ops::Deref for Child {
1033    type Target = smol::process::Child;
1034
1035    fn deref(&self) -> &Self::Target {
1036        &self.process
1037    }
1038}
1039
1040impl std::ops::DerefMut for Child {
1041    fn deref_mut(&mut self) -> &mut Self::Target {
1042        &mut self.process
1043    }
1044}
1045
1046impl Child {
1047    fn into_inner(self) -> smol::process::Child {
1048        self.process
1049    }
1050
1051    #[cfg(not(windows))]
1052    fn spawn(mut command: std::process::Command, stdin: Stdio) -> Result<Self> {
1053        util::set_pre_exec_to_start_new_session(&mut command);
1054        let mut command = smol::process::Command::from(command);
1055        let process = command
1056            .stdin(stdin)
1057            .stdout(Stdio::piped())
1058            .stderr(Stdio::piped())
1059            .spawn()
1060            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1061        Ok(Self { process })
1062    }
1063
1064    #[cfg(windows)]
1065    fn spawn(command: std::process::Command, stdin: Stdio) -> Result<Self> {
1066        // TODO(windows): create a job object and add the child process handle to it,
1067        // see https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects
1068        let mut command = smol::process::Command::from(command);
1069        let process = command
1070            .stdin(stdin)
1071            .stdout(Stdio::piped())
1072            .stderr(Stdio::piped())
1073            .spawn()
1074            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1075        Ok(Self { process })
1076    }
1077
1078    #[cfg(not(windows))]
1079    fn kill(&mut self) {
1080        let pid = self.process.id();
1081        unsafe {
1082            libc::killpg(pid as i32, libc::SIGKILL);
1083        }
1084    }
1085
1086    #[cfg(windows)]
1087    fn kill(&mut self) {
1088        // TODO(windows): terminate the job object in kill
1089        let _ = self.process.kill();
1090    }
1091}