transport.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use dap_types::{
   3    ErrorResponse,
   4    messages::{Message, Response},
   5};
   6use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
   7use gpui::{AppContext as _, AsyncApp, Task};
   8use settings::Settings as _;
   9use smallvec::SmallVec;
  10use smol::{
  11    channel::{Receiver, Sender, unbounded},
  12    io::{AsyncBufReadExt as _, AsyncWriteExt, BufReader},
  13    lock::Mutex,
  14    net::{TcpListener, TcpStream},
  15};
  16use std::{
  17    collections::HashMap,
  18    net::{Ipv4Addr, SocketAddrV4},
  19    process::Stdio,
  20    sync::Arc,
  21    time::Duration,
  22};
  23use task::TcpArgumentsTemplate;
  24use util::ConnectionResult;
  25
  26use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
  27
/// Boxed callback that receives the kind of I/O stream a piece of text belongs
/// to together with the text itself. It must be `Send` because the handlers are
/// invoked from the background tasks that read from and write to the adapter.
pub type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
  29
/// Category a registered log handler is interested in.
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum LogKind {
    /// Raw lines read from the adapter's log streams (stdout log pipe / stderr).
    Adapter,
    /// Serialized protocol messages sent to or received from the adapter.
    Rpc,
}
  35
/// Identifies the stream a logged piece of text is associated with when it is
/// handed to an [`IoHandler`].
pub enum IoKind {
    /// Text written to the adapter (outgoing messages).
    StdIn,
    /// Text read from the adapter's output (incoming messages or log lines).
    StdOut,
    /// Text read from the adapter's error stream.
    StdErr,
}
  41
/// The set of streams used to talk to a debug adapter, independent of the
/// concrete transport that produced them.
pub struct TransportPipe {
    /// Sink that protocol messages for the adapter are written to.
    input: Box<dyn AsyncWrite + Unpin + Send + 'static>,
    /// Source that protocol messages from the adapter are read from.
    output: Box<dyn AsyncRead + Unpin + Send + 'static>,
    /// Optional extra stream that is read line by line as adapter log output.
    stdout: Option<Box<dyn AsyncRead + Unpin + Send + 'static>>,
    /// Optional error stream that is read line by line as adapter log output.
    stderr: Option<Box<dyn AsyncRead + Unpin + Send + 'static>>,
}
  48
  49impl TransportPipe {
  50    pub fn new(
  51        input: Box<dyn AsyncWrite + Unpin + Send + 'static>,
  52        output: Box<dyn AsyncRead + Unpin + Send + 'static>,
  53        stdout: Option<Box<dyn AsyncRead + Unpin + Send + 'static>>,
  54        stderr: Option<Box<dyn AsyncRead + Unpin + Send + 'static>>,
  55    ) -> Self {
  56        TransportPipe {
  57            input,
  58            output,
  59            stdout,
  60            stderr,
  61        }
  62    }
  63}
  64
/// Shared map from a request sequence number to the channel on which the
/// response (or an error) for that request is delivered.
type Requests = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Response>>>>>;
/// Shared list of registered log handlers, each tagged with the [`LogKind`] it
/// wants to receive. Guarded by a synchronous `parking_lot` mutex.
type LogHandlers = Arc<parking_lot::Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
  67
/// The concrete mechanism used to communicate with a debug adapter.
pub enum Transport {
    /// Adapter spawned as a child process and spoken to over its stdin/stdout.
    Stdio(StdioTransport),
    /// Adapter reached over a TCP connection.
    Tcp(TcpTransport),
    /// In-process fake adapter, only available in test / test-support builds.
    #[cfg(any(test, feature = "test-support"))]
    Fake(FakeTransport),
}
  74
  75impl Transport {
    /// Starts the transport matching `binary` and returns its streams together
    /// with the transport handle.
    ///
    /// A TCP transport is used when `binary.connection` is set, otherwise the
    /// adapter is driven over stdio. In test / test-support builds the fake
    /// transport is always returned instead.
    ///
    /// # Errors
    /// Propagates the error of the chosen transport's `start`, annotated with
    /// which transport layer was attempted.
    async fn start(binary: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
        // The attribute removes this statement from non-test builds. Inside it,
        // `cfg!(..)` is always `true`.
        // NOTE(review): the `if` wrapper presumably exists so the code below is
        // not flagged as unreachable in test builds — confirm before simplifying.
        #[cfg(any(test, feature = "test-support"))]
        if cfg!(any(test, feature = "test-support")) {
            return FakeTransport::start(cx)
                .await
                .map(|(transports, fake)| (transports, Self::Fake(fake)));
        }

        if binary.connection.is_some() {
            TcpTransport::start(binary, cx)
                .await
                .map(|(transports, tcp)| (transports, Self::Tcp(tcp)))
                .context("Tried to connect to a debug adapter via TCP transport layer")
        } else {
            StdioTransport::start(binary, cx)
                .await
                .map(|(transports, stdio)| (transports, Self::Stdio(stdio)))
                .context("Tried to connect to a debug adapter via stdin/stdout transport layer")
        }
    }
  96
  97    fn has_adapter_logs(&self) -> bool {
  98        match self {
  99            Transport::Stdio(stdio_transport) => stdio_transport.has_adapter_logs(),
 100            Transport::Tcp(tcp_transport) => tcp_transport.has_adapter_logs(),
 101            #[cfg(any(test, feature = "test-support"))]
 102            Transport::Fake(fake_transport) => fake_transport.has_adapter_logs(),
 103        }
 104    }
 105
 106    async fn kill(&self) {
 107        match self {
 108            Transport::Stdio(stdio_transport) => stdio_transport.kill().await,
 109            Transport::Tcp(tcp_transport) => tcp_transport.kill().await,
 110            #[cfg(any(test, feature = "test-support"))]
 111            Transport::Fake(fake_transport) => fake_transport.kill().await,
 112        }
 113    }
 114
 115    #[cfg(any(test, feature = "test-support"))]
 116    pub(crate) fn as_fake(&self) -> &FakeTransport {
 117        match self {
 118            Transport::Fake(fake_transport) => fake_transport,
 119            _ => panic!("Not a fake transport layer"),
 120        }
 121    }
 122}
 123
/// Owns a [`Transport`] plus the background tasks and bookkeeping needed to
/// exchange messages with the adapter behind it.
pub(crate) struct TransportDelegate {
    /// Handlers notified about adapter log lines and RPC traffic.
    log_handlers: LogHandlers,
    /// Response channels for requests that have not been written to the adapter
    /// yet; the input task moves an entry to `pending_requests` when it sends
    /// the request with the matching sequence number.
    current_requests: Requests,
    /// Response channels for requests awaiting a response from the adapter.
    pending_requests: Requests,
    /// The underlying transport; killed on shutdown.
    transport: Transport,
    /// Sender feeding the input task; set to `None` once shutdown has run.
    server_tx: Arc<Mutex<Option<Sender<Message>>>>,
    /// Handles of the spawned I/O tasks, stored alongside the delegate.
    _tasks: Vec<Task<()>>,
}
 132
 133impl TransportDelegate {
 134    pub(crate) async fn start(
 135        binary: &DebugAdapterBinary,
 136        cx: AsyncApp,
 137    ) -> Result<((Receiver<Message>, Sender<Message>), Self)> {
 138        let (transport_pipes, transport) = Transport::start(binary, cx.clone()).await?;
 139        let mut this = Self {
 140            transport,
 141            server_tx: Default::default(),
 142            log_handlers: Default::default(),
 143            current_requests: Default::default(),
 144            pending_requests: Default::default(),
 145            _tasks: Vec::new(),
 146        };
 147        let messages = this.start_handlers(transport_pipes, cx).await?;
 148        Ok((messages, this))
 149    }
 150
    /// Spawns the background tasks that shuttle data between the adapter
    /// streams in `params` and a pair of unbounded channels.
    ///
    /// Returns `(server_rx, server_tx)`: the receiver yields messages read from
    /// the adapter, the sender queues messages to be written to it. A clone of
    /// the sender is also stored in `self.server_tx`.
    ///
    /// # Errors
    /// Fails when `cx.update` fails while spawning the tasks.
    async fn start_handlers(
        &mut self,
        mut params: TransportPipe,
        cx: AsyncApp,
    ) -> Result<(Receiver<Message>, Sender<Message>)> {
        // Adapter -> caller channel.
        let (client_tx, server_rx) = unbounded::<Message>();
        // Caller -> adapter channel.
        let (server_tx, client_rx) = unbounded::<Message>();

        // If the setting cannot be read, logging is treated as disabled. The
        // context attached here is discarded by `unwrap_or` and never reported.
        let log_dap_communications =
            cx.update(|cx| DebuggerSettings::get_global(cx).log_dap_communications)
                .with_context(|| "Failed to get Debugger Setting log dap communications error in transport::start_handlers. Defaulting to false")
                .unwrap_or(false);

        // `None` disables handler invocation in the stdout-log, output and
        // input tasks. The stderr task below always gets the handlers.
        let log_handler = if log_dap_communications {
            Some(self.log_handlers.clone())
        } else {
            None
        };

        let adapter_log_handler = log_handler.clone();
        cx.update(|cx| {
            // Task 1 (optional): forward the adapter's log stream line by line.
            if let Some(stdout) = params.stdout.take() {
                self._tasks.push(cx.background_spawn(async move {
                    match Self::handle_adapter_log(stdout, adapter_log_handler).await {
                        ConnectionResult::Timeout => {
                            log::error!("Timed out when handling debugger log");
                        }
                        ConnectionResult::ConnectionReset => {
                            log::info!("Debugger logs connection closed");
                        }
                        ConnectionResult::Result(Ok(())) => {}
                        ConnectionResult::Result(Err(e)) => {
                            log::error!("Error handling debugger log: {e}");
                        }
                    }
                }));
            }

            // Task 2: read protocol messages from the adapter.
            let pending_requests = self.pending_requests.clone();
            let output_log_handler = log_handler.clone();
            self._tasks.push(cx.background_spawn(async move {
                match Self::handle_output(
                    params.output,
                    client_tx,
                    pending_requests.clone(),
                    output_log_handler,
                )
                .await
                {
                    Ok(()) => {}
                    Err(e) => log::error!("Error handling debugger output: {e}"),
                }
                // Once reading has stopped no response can arrive anymore, so
                // every request still waiting is failed explicitly.
                let mut pending_requests = pending_requests.lock().await;
                pending_requests.drain().for_each(|(_, request)| {
                    request
                        .send(Err(anyhow!("debugger shutdown unexpectedly")))
                        .ok();
                });
            }));

            // Task 3 (optional): forward the adapter's stderr line by line.
            if let Some(stderr) = params.stderr.take() {
                let log_handlers = self.log_handlers.clone();
                self._tasks.push(cx.background_spawn(async move {
                    match Self::handle_error(stderr, log_handlers).await {
                        ConnectionResult::Timeout => {
                            log::error!("Timed out reading debugger error stream")
                        }
                        ConnectionResult::ConnectionReset => {
                            log::info!("Debugger closed its error stream")
                        }
                        ConnectionResult::Result(Ok(())) => {}
                        ConnectionResult::Result(Err(e)) => {
                            log::error!("Error handling debugger error: {e}")
                        }
                    }
                }));
            }

            // Task 4: write queued messages to the adapter.
            let current_requests = self.current_requests.clone();
            let pending_requests = self.pending_requests.clone();
            let log_handler = log_handler.clone();
            self._tasks.push(cx.background_spawn(async move {
                match Self::handle_input(
                    params.input,
                    client_rx,
                    current_requests,
                    pending_requests,
                    log_handler,
                )
                .await
                {
                    Ok(()) => {}
                    Err(e) => log::error!("Error handling debugger input: {e}"),
                }
            }));
        })?;

        // Keep a sender so `send_message` / `shutdown` can reach the input task.
        {
            let mut lock = self.server_tx.lock().await;
            *lock = Some(server_tx.clone());
        }

        Ok((server_rx, server_tx))
    }
 255
 256    pub(crate) async fn add_pending_request(
 257        &self,
 258        sequence_id: u64,
 259        request: oneshot::Sender<Result<Response>>,
 260    ) {
 261        let mut pending_requests = self.pending_requests.lock().await;
 262        pending_requests.insert(sequence_id, request);
 263    }
 264
 265    pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
 266        if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
 267            server_tx.send(message).await.context("sending message")
 268        } else {
 269            anyhow::bail!("Server tx already dropped")
 270        }
 271    }
 272
 273    async fn handle_adapter_log<Stdout>(
 274        stdout: Stdout,
 275        log_handlers: Option<LogHandlers>,
 276    ) -> ConnectionResult<()>
 277    where
 278        Stdout: AsyncRead + Unpin + Send + 'static,
 279    {
 280        let mut reader = BufReader::new(stdout);
 281        let mut line = String::new();
 282
 283        let result = loop {
 284            line.truncate(0);
 285
 286            match reader
 287                .read_line(&mut line)
 288                .await
 289                .context("reading adapter log line")
 290            {
 291                Ok(0) => break ConnectionResult::ConnectionReset,
 292                Ok(_) => {}
 293                Err(e) => break ConnectionResult::Result(Err(e)),
 294            }
 295
 296            if let Some(log_handlers) = log_handlers.as_ref() {
 297                for (kind, handler) in log_handlers.lock().iter_mut() {
 298                    if matches!(kind, LogKind::Adapter) {
 299                        handler(IoKind::StdOut, line.as_str());
 300                    }
 301                }
 302            }
 303        };
 304
 305        log::debug!("Handle adapter log dropped");
 306
 307        result
 308    }
 309
 310    fn build_rpc_message(message: String) -> String {
 311        format!("Content-Length: {}\r\n\r\n{}", message.len(), message)
 312    }
 313
 314    async fn handle_input<Stdin>(
 315        mut server_stdin: Stdin,
 316        client_rx: Receiver<Message>,
 317        current_requests: Requests,
 318        pending_requests: Requests,
 319        log_handlers: Option<LogHandlers>,
 320    ) -> Result<()>
 321    where
 322        Stdin: AsyncWrite + Unpin + Send + 'static,
 323    {
 324        let result = loop {
 325            match client_rx.recv().await {
 326                Ok(message) => {
 327                    if let Message::Request(request) = &message {
 328                        if let Some(sender) = current_requests.lock().await.remove(&request.seq) {
 329                            pending_requests.lock().await.insert(request.seq, sender);
 330                        }
 331                    }
 332
 333                    let message = match serde_json::to_string(&message) {
 334                        Ok(message) => message,
 335                        Err(e) => break Err(e.into()),
 336                    };
 337
 338                    if let Some(log_handlers) = log_handlers.as_ref() {
 339                        for (kind, log_handler) in log_handlers.lock().iter_mut() {
 340                            if matches!(kind, LogKind::Rpc) {
 341                                log_handler(IoKind::StdIn, &message);
 342                            }
 343                        }
 344                    }
 345
 346                    if let Err(e) = server_stdin
 347                        .write_all(Self::build_rpc_message(message).as_bytes())
 348                        .await
 349                    {
 350                        break Err(e.into());
 351                    }
 352
 353                    if let Err(e) = server_stdin.flush().await {
 354                        break Err(e.into());
 355                    }
 356                }
 357                Err(error) => break Err(error.into()),
 358            }
 359        };
 360
 361        log::debug!("Handle adapter input dropped");
 362
 363        result
 364    }
 365
 366    async fn handle_output<Stdout>(
 367        server_stdout: Stdout,
 368        client_tx: Sender<Message>,
 369        pending_requests: Requests,
 370        log_handlers: Option<LogHandlers>,
 371    ) -> Result<()>
 372    where
 373        Stdout: AsyncRead + Unpin + Send + 'static,
 374    {
 375        let mut recv_buffer = String::new();
 376        let mut reader = BufReader::new(server_stdout);
 377
 378        let result = loop {
 379            match Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
 380                .await
 381            {
 382                ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
 383                ConnectionResult::ConnectionReset => {
 384                    log::info!("Debugger closed the connection");
 385                    return Ok(());
 386                }
 387                ConnectionResult::Result(Ok(Message::Response(res))) => {
 388                    if let Some(tx) = pending_requests.lock().await.remove(&res.request_seq) {
 389                        if let Err(e) = tx.send(Self::process_response(res)) {
 390                            log::trace!("Did not send response `{:?}` for a cancelled", e);
 391                        }
 392                    } else {
 393                        client_tx.send(Message::Response(res)).await?;
 394                    }
 395                }
 396                ConnectionResult::Result(Ok(message)) => client_tx.send(message).await?,
 397                ConnectionResult::Result(Err(e)) => break Err(e),
 398            }
 399        };
 400
 401        drop(client_tx);
 402        log::debug!("Handle adapter output dropped");
 403
 404        result
 405    }
 406
 407    async fn handle_error<Stderr>(stderr: Stderr, log_handlers: LogHandlers) -> ConnectionResult<()>
 408    where
 409        Stderr: AsyncRead + Unpin + Send + 'static,
 410    {
 411        log::debug!("Handle error started");
 412        let mut buffer = String::new();
 413
 414        let mut reader = BufReader::new(stderr);
 415
 416        let result = loop {
 417            match reader
 418                .read_line(&mut buffer)
 419                .await
 420                .context("reading error log line")
 421            {
 422                Ok(0) => break ConnectionResult::ConnectionReset,
 423                Ok(_) => {
 424                    for (kind, log_handler) in log_handlers.lock().iter_mut() {
 425                        if matches!(kind, LogKind::Adapter) {
 426                            log_handler(IoKind::StdErr, buffer.as_str());
 427                        }
 428                    }
 429
 430                    buffer.truncate(0);
 431                }
 432                Err(error) => break ConnectionResult::Result(Err(error)),
 433            }
 434        };
 435
 436        log::debug!("Handle adapter error dropped");
 437
 438        result
 439    }
 440
 441    fn process_response(response: Response) -> Result<Response> {
 442        if response.success {
 443            Ok(response)
 444        } else {
 445            if let Some(error_message) = response
 446                .body
 447                .clone()
 448                .and_then(|body| serde_json::from_value::<ErrorResponse>(body).ok())
 449                .and_then(|response| response.error.map(|msg| msg.format))
 450                .or_else(|| response.message.clone())
 451            {
 452                anyhow::bail!(error_message);
 453            };
 454
 455            anyhow::bail!(
 456                "Received error response from adapter. Response: {:?}",
 457                response
 458            );
 459        }
 460    }
 461
 462    async fn receive_server_message<Stdout>(
 463        reader: &mut BufReader<Stdout>,
 464        buffer: &mut String,
 465        log_handlers: Option<&LogHandlers>,
 466    ) -> ConnectionResult<Message>
 467    where
 468        Stdout: AsyncRead + Unpin + Send + 'static,
 469    {
 470        let mut content_length = None;
 471        loop {
 472            buffer.truncate(0);
 473
 474            match reader
 475                .read_line(buffer)
 476                .await
 477                .with_context(|| "reading a message from server")
 478            {
 479                Ok(0) => return ConnectionResult::ConnectionReset,
 480                Ok(_) => {}
 481                Err(e) => return ConnectionResult::Result(Err(e)),
 482            };
 483
 484            if buffer == "\r\n" {
 485                break;
 486            }
 487
 488            if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
 489                match value.parse().context("invalid content length") {
 490                    Ok(length) => content_length = Some(length),
 491                    Err(e) => return ConnectionResult::Result(Err(e)),
 492                }
 493            }
 494        }
 495
 496        let content_length = match content_length.context("missing content length") {
 497            Ok(length) => length,
 498            Err(e) => return ConnectionResult::Result(Err(e)),
 499        };
 500
 501        let mut content = vec![0; content_length];
 502        if let Err(e) = reader
 503            .read_exact(&mut content)
 504            .await
 505            .with_context(|| "reading after a loop")
 506        {
 507            return ConnectionResult::Result(Err(e));
 508        }
 509
 510        let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
 511            Ok(str) => str,
 512            Err(e) => return ConnectionResult::Result(Err(e)),
 513        };
 514
 515        if let Some(log_handlers) = log_handlers {
 516            for (kind, log_handler) in log_handlers.lock().iter_mut() {
 517                if matches!(kind, LogKind::Rpc) {
 518                    log_handler(IoKind::StdOut, message_str);
 519                }
 520            }
 521        }
 522
 523        ConnectionResult::Result(
 524            serde_json::from_str::<Message>(message_str).context("deserializing server message"),
 525        )
 526    }
 527
 528    pub async fn shutdown(&self) -> Result<()> {
 529        log::debug!("Start shutdown client");
 530
 531        if let Some(server_tx) = self.server_tx.lock().await.take().as_ref() {
 532            server_tx.close();
 533        }
 534
 535        let mut current_requests = self.current_requests.lock().await;
 536        let mut pending_requests = self.pending_requests.lock().await;
 537
 538        current_requests.clear();
 539        pending_requests.clear();
 540
 541        self.transport.kill().await;
 542
 543        drop(current_requests);
 544        drop(pending_requests);
 545
 546        log::debug!("Shutdown client completed");
 547
 548        anyhow::Ok(())
 549    }
 550
    /// Returns whatever the underlying transport reports from its own
    /// `has_adapter_logs`.
    pub fn has_adapter_logs(&self) -> bool {
        self.transport.has_adapter_logs()
    }
 554
    /// Borrows the underlying transport.
    pub fn transport(&self) -> &Transport {
        &self.transport
    }
 558
 559    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
 560    where
 561        F: 'static + Send + FnMut(IoKind, &str),
 562    {
 563        let mut log_handlers = self.log_handlers.lock();
 564        log_handlers.push((kind, Box::new(f)));
 565    }
 566}
 567
/// Transport that talks to a debug adapter over a TCP connection, optionally
/// owning the adapter process it spawned.
pub struct TcpTransport {
    /// Port the adapter was connected to.
    pub port: u16,
    /// Host the adapter was connected to.
    pub host: Ipv4Addr,
    /// Connection timeout in milliseconds that was applied when connecting.
    pub timeout: u64,
    /// The spawned adapter process, or `None` when no command was configured.
    process: Option<Mutex<Child>>,
}
 574
 575impl TcpTransport {
 576    /// Get an open port to use with the tcp client when not supplied by debug config
 577    pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
 578        if let Some(port) = host.port {
 579            Ok(port)
 580        } else {
 581            Self::unused_port(host.host()).await
 582        }
 583    }
 584
 585    pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
 586        Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
 587            .await?
 588            .local_addr()?
 589            .port())
 590    }
 591
    /// Connects to a debug adapter over TCP, spawning the adapter process first
    /// when `binary.command` is set.
    ///
    /// The connection is retried every 100 ms until it succeeds, the spawned
    /// process is found to have exited, or the timeout elapses. The timeout (in
    /// milliseconds) comes from the connection arguments, falling back to the
    /// debugger settings and finally to 2000 when the settings cannot be read.
    ///
    /// # Errors
    /// Fails when no connection arguments are present, the process cannot be
    /// spawned, the timeout elapses, or the process exits before a connection
    /// is made (the error then contains the process's stderr, or its stdout
    /// when stderr is empty).
    async fn start(binary: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
        let connection_args = binary
            .connection
            .as_ref()
            .context("No connection arguments provided")?;

        let host = connection_args.host;
        let port = connection_args.port;

        // The adapter process is optional: without a command we only connect.
        let mut process = if let Some(command) = &binary.command {
            let mut command = util::command::new_std_command(&command);

            if let Some(cwd) = &binary.cwd {
                command.current_dir(cwd);
            }

            command.args(&binary.arguments);
            command.envs(&binary.envs);

            Some(
                Child::spawn(command, Stdio::null())
                    .with_context(|| "failed to start debug adapter.")?,
            )
        } else {
            None
        };

        let address = SocketAddrV4::new(host, port);

        let timeout = connection_args.timeout.unwrap_or_else(|| {
            cx.update(|cx| DebuggerSettings::get_global(cx).timeout)
                .unwrap_or(2000u64)
        });

        // Race the connect loop against the timeout timer. `process` is moved
        // into the connect task and handed back on success.
        // NOTE(review): on timeout the task (and the process it owns) is
        // dropped; whether that terminates the process depends on `Child`'s
        // drop behaviour, which is not visible here — confirm.
        let (mut process, (rx, tx)) = select! {
            _ = cx.background_executor().timer(Duration::from_millis(timeout)).fuse() => {
                anyhow::bail!("Connection to TCP DAP timeout {host}:{port}");
            },
            result = cx.spawn(async move |cx| {
                loop {
                    match TcpStream::connect(address).await {
                        Ok(stream) => return Ok((process, stream.split())),
                        Err(_) => {
                            // Connecting failed; stop retrying if the adapter
                            // process has already exited.
                            if let Some(p) = &mut process {
                                if let Ok(Some(_)) = p.try_status() {
                                    let output = process.take().unwrap().into_inner().output().await?;
                                    let output = if output.stderr.is_empty() {
                                        String::from_utf8_lossy(&output.stdout).to_string()
                                    } else {
                                        String::from_utf8_lossy(&output.stderr).to_string()
                                    };
                                    anyhow::bail!("{output}\nerror: process exited before debugger attached.");
                                }
                            }

                            cx.background_executor().timer(Duration::from_millis(100)).await;
                        }
                    }
                }
            }).fuse() => result?
        };

        log::info!(
            "Debug adapter has connected to TCP server {}:{}",
            host,
            port
        );
        // The process's stdout/stderr are not used for protocol traffic here;
        // they are handed to the pipe as log streams.
        let stdout = process.as_mut().and_then(|p| p.stdout.take());
        let stderr = process.as_mut().and_then(|p| p.stderr.take());

        let this = Self {
            port,
            host,
            process: process.map(Mutex::new),
            timeout,
        };

        let pipe = TransportPipe::new(
            Box::new(tx),
            Box::new(BufReader::new(rx)),
            stdout.map(|s| Box::new(s) as Box<dyn AsyncRead + Unpin + Send>),
            stderr.map(|s| Box::new(s) as Box<dyn AsyncRead + Unpin + Send>),
        );

        Ok((pipe, this))
    }
 678
    /// Always `true` for the TCP transport.
    fn has_adapter_logs(&self) -> bool {
        true
    }
 682
 683    async fn kill(&self) {
 684        if let Some(process) = &self.process {
 685            let mut process = process.lock().await;
 686            Child::kill(&mut process);
 687        }
 688    }
 689}
 690
 691impl Drop for TcpTransport {
 692    fn drop(&mut self) {
 693        if let Some(mut p) = self.process.take() {
 694            p.get_mut().kill();
 695        }
 696    }
 697}
 698
/// Transport that talks to a debug adapter over the stdin/stdout of a child
/// process it spawned.
pub struct StdioTransport {
    /// The spawned adapter process.
    process: Mutex<Child>,
}
 702
 703impl StdioTransport {
 704    #[allow(dead_code, reason = "This is used in non test builds of Zed")]
 705    async fn start(binary: &DebugAdapterBinary, _: AsyncApp) -> Result<(TransportPipe, Self)> {
 706        let Some(binary_command) = &binary.command else {
 707            bail!(
 708                "When using the `stdio` transport, the path to a debug adapter binary must be set by Zed."
 709            );
 710        };
 711        let mut command = util::command::new_std_command(&binary_command);
 712
 713        if let Some(cwd) = &binary.cwd {
 714            command.current_dir(cwd);
 715        }
 716
 717        command.args(&binary.arguments);
 718        command.envs(&binary.envs);
 719
 720        let mut process = Child::spawn(command, Stdio::piped()).with_context(|| {
 721            format!(
 722                "failed to spawn command `{} {}`.",
 723                binary_command,
 724                binary.arguments.join(" ")
 725            )
 726        })?;
 727
 728        let stdin = process.stdin.take().context("Failed to open stdin")?;
 729        let stdout = process.stdout.take().context("Failed to open stdout")?;
 730        let stderr = process
 731            .stderr
 732            .take()
 733            .map(|io_err| Box::new(io_err) as Box<dyn AsyncRead + Unpin + Send>);
 734
 735        if stderr.is_none() {
 736            bail!(
 737                "Failed to connect to stderr for debug adapter command {}",
 738                &binary_command
 739            );
 740        }
 741
 742        log::info!("Debug adapter has connected to stdio adapter");
 743
 744        let process = Mutex::new(process);
 745
 746        Ok((
 747            TransportPipe::new(
 748                Box::new(stdin),
 749                Box::new(BufReader::new(stdout)),
 750                None,
 751                stderr,
 752            ),
 753            Self { process },
 754        ))
 755    }
 756
    /// Always `false` for the stdio transport.
    fn has_adapter_logs(&self) -> bool {
        false
    }
 760
 761    async fn kill(&self) {
 762        let mut process = self.process.lock().await;
 763        Child::kill(&mut process);
 764    }
 765}
 766
// Kills the adapter process when the transport is dropped.
impl Drop for StdioTransport {
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so `get_mut` reaches the process
        // without locking the mutex.
        self.process.get_mut().kill();
    }
}
 772
/// Test-only callback that turns a request (its sequence number and raw JSON
/// arguments) into the response the fake adapter should send back.
#[cfg(any(test, feature = "test-support"))]
type RequestHandler =
    Box<dyn Send + FnMut(u64, serde_json::Value) -> dap_types::messages::Response>;
 776
/// Test-only callback invoked with a response.
#[cfg(any(test, feature = "test-support"))]
type ResponseHandler = Box<dyn Send + Fn(Response)>;
 779
/// In-process stand-in for a debug adapter, available in test / test-support
/// builds. Handlers are keyed by the command name they serve.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeTransport {
    // for sending fake response back from adapter side
    request_handlers: Arc<parking_lot::Mutex<HashMap<&'static str, RequestHandler>>>,
    // for reverse request responses
    response_handlers: Arc<parking_lot::Mutex<HashMap<&'static str, ResponseHandler>>>,
}
 787
 788#[cfg(any(test, feature = "test-support"))]
 789impl FakeTransport {
 790    pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
 791    where
 792        F: 'static + Send + FnMut(u64, R::Arguments) -> Result<R::Response, ErrorResponse>,
 793    {
 794        self.request_handlers.lock().insert(
 795            R::COMMAND,
 796            Box::new(move |seq, args| {
 797                let result = handler(seq, serde_json::from_value(args).unwrap());
 798                let response = match result {
 799                    Ok(response) => Response {
 800                        seq: seq + 1,
 801                        request_seq: seq,
 802                        success: true,
 803                        command: R::COMMAND.into(),
 804                        body: Some(serde_json::to_value(response).unwrap()),
 805                        message: None,
 806                    },
 807                    Err(response) => Response {
 808                        seq: seq + 1,
 809                        request_seq: seq,
 810                        success: false,
 811                        command: R::COMMAND.into(),
 812                        body: Some(serde_json::to_value(response).unwrap()),
 813                        message: None,
 814                    },
 815                };
 816                response
 817            }),
 818        );
 819    }
 820
 821    pub async fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
 822    where
 823        F: 'static + Send + Fn(Response),
 824    {
 825        self.response_handlers
 826            .lock()
 827            .insert(R::COMMAND, Box::new(handler));
 828    }
 829
 830    async fn start(cx: AsyncApp) -> Result<(TransportPipe, Self)> {
 831        let this = Self {
 832            request_handlers: Arc::new(parking_lot::Mutex::new(HashMap::default())),
 833            response_handlers: Arc::new(parking_lot::Mutex::new(HashMap::default())),
 834        };
 835        use dap_types::requests::{Request, RunInTerminal, StartDebugging};
 836        use serde_json::json;
 837
 838        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 839        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 840
 841        let request_handlers = this.request_handlers.clone();
 842        let response_handlers = this.response_handlers.clone();
 843        let stdout_writer = Arc::new(Mutex::new(stdout_writer));
 844
 845        cx.background_spawn(async move {
 846            let mut reader = BufReader::new(stdin_reader);
 847            let mut buffer = String::new();
 848
 849            loop {
 850                match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
 851                    .await
 852                {
 853                    ConnectionResult::Timeout => {
 854                        anyhow::bail!("Timed out when connecting to debugger");
 855                    }
 856                    ConnectionResult::ConnectionReset => {
 857                        log::info!("Debugger closed the connection");
 858                        break Ok(());
 859                    }
 860                    ConnectionResult::Result(Err(e)) => break Err(e),
 861                    ConnectionResult::Result(Ok(message)) => {
 862                        match message {
 863                            Message::Request(request) => {
 864                                // redirect reverse requests to stdout writer/reader
 865                                if request.command == RunInTerminal::COMMAND
 866                                    || request.command == StartDebugging::COMMAND
 867                                {
 868                                    let message =
 869                                        serde_json::to_string(&Message::Request(request)).unwrap();
 870
 871                                    let mut writer = stdout_writer.lock().await;
 872                                    writer
 873                                        .write_all(
 874                                            TransportDelegate::build_rpc_message(message)
 875                                                .as_bytes(),
 876                                        )
 877                                        .await
 878                                        .unwrap();
 879                                    writer.flush().await.unwrap();
 880                                } else {
 881                                    let response = if let Some(handle) =
 882                                        request_handlers.lock().get_mut(request.command.as_str())
 883                                    {
 884                                        handle(request.seq, request.arguments.unwrap_or(json!({})))
 885                                    } else {
 886                                        panic!("No request handler for {}", request.command);
 887                                    };
 888                                    let message =
 889                                        serde_json::to_string(&Message::Response(response))
 890                                            .unwrap();
 891
 892                                    let mut writer = stdout_writer.lock().await;
 893
 894                                    writer
 895                                        .write_all(
 896                                            TransportDelegate::build_rpc_message(message)
 897                                                .as_bytes(),
 898                                        )
 899                                        .await
 900                                        .unwrap();
 901                                    writer.flush().await.unwrap();
 902                                }
 903                            }
 904                            Message::Event(event) => {
 905                                let message =
 906                                    serde_json::to_string(&Message::Event(event)).unwrap();
 907
 908                                let mut writer = stdout_writer.lock().await;
 909                                writer
 910                                    .write_all(
 911                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 912                                    )
 913                                    .await
 914                                    .unwrap();
 915                                writer.flush().await.unwrap();
 916                            }
 917                            Message::Response(response) => {
 918                                if let Some(handle) =
 919                                    response_handlers.lock().get(response.command.as_str())
 920                                {
 921                                    handle(response);
 922                                } else {
 923                                    log::error!("No response handler for {}", response.command);
 924                                }
 925                            }
 926                        }
 927                    }
 928                }
 929            }
 930        })
 931        .detach();
 932
 933        Ok((
 934            TransportPipe::new(Box::new(stdin_writer), Box::new(stdout_reader), None, None),
 935            this,
 936        ))
 937    }
 938
 939    fn has_adapter_logs(&self) -> bool {
 940        false
 941    }
 942
 943    async fn kill(&self) {}
 944}
 945
 946struct Child {
 947    process: smol::process::Child,
 948}
 949
 950impl std::ops::Deref for Child {
 951    type Target = smol::process::Child;
 952
 953    fn deref(&self) -> &Self::Target {
 954        &self.process
 955    }
 956}
 957
 958impl std::ops::DerefMut for Child {
 959    fn deref_mut(&mut self) -> &mut Self::Target {
 960        &mut self.process
 961    }
 962}
 963
 964impl Child {
 965    fn into_inner(self) -> smol::process::Child {
 966        self.process
 967    }
 968
 969    #[cfg(not(windows))]
 970    fn spawn(mut command: std::process::Command, stdin: Stdio) -> Result<Self> {
 971        util::set_pre_exec_to_start_new_session(&mut command);
 972        let process = smol::process::Command::from(command)
 973            .stdin(stdin)
 974            .stdout(Stdio::piped())
 975            .stderr(Stdio::piped())
 976            .spawn()?;
 977        Ok(Self { process })
 978    }
 979
 980    #[cfg(windows)]
 981    fn spawn(command: std::process::Command, stdin: Stdio) -> Result<Self> {
 982        // TODO(windows): create a job object and add the child process handle to it,
 983        // see https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects
 984        let process = smol::process::Command::from(command)
 985            .stdin(stdin)
 986            .stdout(Stdio::piped())
 987            .stderr(Stdio::piped())
 988            .spawn()?;
 989        Ok(Self { process })
 990    }
 991
 992    #[cfg(not(windows))]
 993    fn kill(&mut self) {
 994        let pid = self.process.id();
 995        unsafe {
 996            libc::killpg(pid as i32, libc::SIGKILL);
 997        }
 998    }
 999
1000    #[cfg(windows)]
1001    fn kill(&mut self) {
1002        // TODO(windows): terminate the job object in kill
1003        let _ = self.process.kill();
1004    }
1005}