transport.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2#[cfg(any(test, feature = "test-support"))]
   3use async_pipe::{PipeReader, PipeWriter};
   4use dap_types::{
   5    ErrorResponse,
   6    messages::{Message, Response},
   7};
   8use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
   9use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
  10use parking_lot::Mutex;
  11use proto::ErrorExt;
  12use settings::Settings as _;
  13use smallvec::SmallVec;
  14use smol::{
  15    channel::{Receiver, Sender, unbounded},
  16    io::{AsyncBufReadExt as _, AsyncWriteExt, BufReader},
  17    net::{TcpListener, TcpStream},
  18};
  19use std::{
  20    collections::HashMap,
  21    net::{Ipv4Addr, SocketAddrV4},
  22    process::Stdio,
  23    sync::Arc,
  24    time::Duration,
  25};
  26use task::TcpArgumentsTemplate;
  27use util::ConnectionResult;
  28
  29use crate::{
  30    adapters::{DebugAdapterBinary, TcpArguments},
  31    client::DapMessageHandler,
  32    debugger_settings::DebuggerSettings,
  33};
  34
  35pub(crate) type IoMessage = str;
  36pub(crate) type Command = str;
  37pub type IoHandler = Box<dyn Send + FnMut(IoKind, Option<&Command>, &IoMessage)>;
  38
  39#[derive(PartialEq, Eq, Clone, Copy)]
  40pub enum LogKind {
  41    Adapter,
  42    Rpc,
  43}
  44
  45#[derive(Clone, Copy)]
  46pub enum IoKind {
  47    StdIn,
  48    StdOut,
  49    StdErr,
  50}
  51
  52#[cfg(any(test, feature = "test-support"))]
  53pub enum RequestHandling<T> {
  54    Respond(T),
  55    Exit,
  56}
  57
  58type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
  59
  60pub trait Transport: Send + Sync {
  61    fn has_adapter_logs(&self) -> bool;
  62    fn tcp_arguments(&self) -> Option<TcpArguments>;
  63    fn connect(
  64        &mut self,
  65    ) -> Task<
  66        Result<(
  67            Box<dyn AsyncWrite + Unpin + Send + 'static>,
  68            Box<dyn AsyncRead + Unpin + Send + 'static>,
  69        )>,
  70    >;
  71    fn kill(&mut self);
  72    #[cfg(any(test, feature = "test-support"))]
  73    fn as_fake(&self) -> &FakeTransport {
  74        unreachable!()
  75    }
  76}
  77
  78async fn start(
  79    binary: &DebugAdapterBinary,
  80    log_handlers: LogHandlers,
  81    cx: &mut AsyncApp,
  82) -> Result<Box<dyn Transport>> {
  83    #[cfg(any(test, feature = "test-support"))]
  84    if cfg!(any(test, feature = "test-support")) {
  85        if let Some(connection) = binary.connection.clone() {
  86            return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
  87        } else {
  88            return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
  89        }
  90    }
  91
  92    if binary.connection.is_some() {
  93        Ok(Box::new(
  94            TcpTransport::start(binary, log_handlers, cx).await?,
  95        ))
  96    } else {
  97        Ok(Box::new(
  98            StdioTransport::start(binary, log_handlers, cx).await?,
  99        ))
 100    }
 101}
 102
 103pub(crate) struct PendingRequests {
 104    inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
 105}
 106
 107impl PendingRequests {
 108    fn new() -> Self {
 109        Self {
 110            inner: Some(HashMap::default()),
 111        }
 112    }
 113
 114    fn flush(&mut self, e: anyhow::Error) {
 115        let Some(inner) = self.inner.as_mut() else {
 116            return;
 117        };
 118        for (_, sender) in inner.drain() {
 119            sender.send(Err(e.cloned())).ok();
 120        }
 121    }
 122
 123    pub(crate) fn insert(
 124        &mut self,
 125        sequence_id: u64,
 126        callback_tx: oneshot::Sender<Result<Response>>,
 127    ) -> anyhow::Result<()> {
 128        let Some(inner) = self.inner.as_mut() else {
 129            bail!("client is closed")
 130        };
 131        inner.insert(sequence_id, callback_tx);
 132        Ok(())
 133    }
 134
 135    pub(crate) fn remove(
 136        &mut self,
 137        sequence_id: u64,
 138    ) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
 139        let Some(inner) = self.inner.as_mut() else {
 140            bail!("client is closed");
 141        };
 142        Ok(inner.remove(&sequence_id))
 143    }
 144
 145    pub(crate) fn shutdown(&mut self) {
 146        self.flush(anyhow!("transport shutdown"));
 147        self.inner = None;
 148    }
 149}
 150
 151pub(crate) struct TransportDelegate {
 152    log_handlers: LogHandlers,
 153    pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
 154    pub(crate) transport: Mutex<Box<dyn Transport>>,
 155    pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
 156    tasks: Mutex<Vec<Task<()>>>,
 157}
 158
 159impl TransportDelegate {
 160    pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
 161        let log_handlers: LogHandlers = Default::default();
 162        let transport = start(binary, log_handlers.clone(), cx).await?;
 163        Ok(Self {
 164            transport: Mutex::new(transport),
 165            log_handlers,
 166            server_tx: Default::default(),
 167            pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
 168            tasks: Default::default(),
 169        })
 170    }
 171
 172    pub async fn connect(
 173        &self,
 174        message_handler: DapMessageHandler,
 175        cx: &mut AsyncApp,
 176    ) -> Result<()> {
 177        let (server_tx, client_rx) = unbounded::<Message>();
 178        self.tasks.lock().clear();
 179
 180        let log_dap_communications =
 181            cx.update(|cx| DebuggerSettings::get_global(cx).log_dap_communications)
 182                .with_context(|| "Failed to get Debugger Setting log dap communications error in transport::start_handlers. Defaulting to false")
 183                .unwrap_or(false);
 184
 185        let connect = self.transport.lock().connect();
 186        let (input, output) = connect.await?;
 187
 188        let log_handler = if log_dap_communications {
 189            Some(self.log_handlers.clone())
 190        } else {
 191            None
 192        };
 193
 194        let pending_requests = self.pending_requests.clone();
 195        let output_log_handler = log_handler.clone();
 196        {
 197            let mut tasks = self.tasks.lock();
 198            tasks.push(cx.background_spawn(async move {
 199                match Self::recv_from_server(
 200                    output,
 201                    message_handler,
 202                    pending_requests.clone(),
 203                    output_log_handler,
 204                )
 205                .await
 206                {
 207                    Ok(()) => {
 208                        pending_requests
 209                            .lock()
 210                            .flush(anyhow!("debugger shutdown unexpectedly"));
 211                    }
 212                    Err(e) => {
 213                        pending_requests.lock().flush(e);
 214                    }
 215                }
 216            }));
 217
 218            tasks.push(cx.background_spawn(async move {
 219                match Self::send_to_server(input, client_rx, log_handler).await {
 220                    Ok(()) => {}
 221                    Err(e) => log::error!("Error handling debugger input: {e}"),
 222                }
 223            }));
 224        }
 225
 226        {
 227            let mut lock = self.server_tx.lock().await;
 228            *lock = Some(server_tx.clone());
 229        }
 230
 231        Ok(())
 232    }
 233
 234    pub(crate) fn tcp_arguments(&self) -> Option<TcpArguments> {
 235        self.transport.lock().tcp_arguments()
 236    }
 237
 238    pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
 239        if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
 240            server_tx.send(message).await.context("sending message")
 241        } else {
 242            anyhow::bail!("Server tx already dropped")
 243        }
 244    }
 245
 246    async fn handle_adapter_log(
 247        stdout: impl AsyncRead + Unpin + Send + 'static,
 248        iokind: IoKind,
 249        log_handlers: LogHandlers,
 250    ) {
 251        let mut reader = BufReader::new(stdout);
 252        let mut line = String::new();
 253
 254        loop {
 255            line.truncate(0);
 256
 257            match reader.read_line(&mut line).await {
 258                Ok(0) => break,
 259                Ok(_) => {}
 260                Err(e) => {
 261                    log::debug!("handle_adapter_log: {}", e);
 262                    break;
 263                }
 264            }
 265            log::debug!("stderr: {line}");
 266
 267            for (kind, handler) in log_handlers.lock().iter_mut() {
 268                if matches!(kind, LogKind::Adapter) {
 269                    handler(iokind, None, line.as_str());
 270                }
 271            }
 272        }
 273    }
 274
 275    fn build_rpc_message(message: String) -> String {
 276        format!("Content-Length: {}\r\n\r\n{}", message.len(), message)
 277    }
 278
 279    async fn send_to_server<Stdin>(
 280        mut server_stdin: Stdin,
 281        client_rx: Receiver<Message>,
 282        log_handlers: Option<LogHandlers>,
 283    ) -> Result<()>
 284    where
 285        Stdin: AsyncWrite + Unpin + Send + 'static,
 286    {
 287        let result = loop {
 288            match client_rx.recv().await {
 289                Ok(message) => {
 290                    let command = match &message {
 291                        Message::Request(request) => Some(request.command.as_str()),
 292                        Message::Response(response) => Some(response.command.as_str()),
 293                        _ => None,
 294                    };
 295
 296                    let message = match serde_json::to_string(&message) {
 297                        Ok(message) => message,
 298                        Err(e) => break Err(e.into()),
 299                    };
 300
 301                    if let Some(log_handlers) = log_handlers.as_ref() {
 302                        for (kind, log_handler) in log_handlers.lock().iter_mut() {
 303                            if matches!(kind, LogKind::Rpc) {
 304                                log_handler(IoKind::StdIn, command, &message);
 305                            }
 306                        }
 307                    }
 308
 309                    if let Err(e) = server_stdin
 310                        .write_all(Self::build_rpc_message(message).as_bytes())
 311                        .await
 312                    {
 313                        break Err(e.into());
 314                    }
 315
 316                    if let Err(e) = server_stdin.flush().await {
 317                        break Err(e.into());
 318                    }
 319                }
 320                Err(error) => break Err(error.into()),
 321            }
 322        };
 323
 324        log::debug!("Handle adapter input dropped");
 325
 326        result
 327    }
 328
 329    async fn recv_from_server<Stdout>(
 330        server_stdout: Stdout,
 331        mut message_handler: DapMessageHandler,
 332        pending_requests: Arc<Mutex<PendingRequests>>,
 333        log_handlers: Option<LogHandlers>,
 334    ) -> Result<()>
 335    where
 336        Stdout: AsyncRead + Unpin + Send + 'static,
 337    {
 338        let mut recv_buffer = String::new();
 339        let mut reader = BufReader::new(server_stdout);
 340
 341        let result = loop {
 342            let result =
 343                Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
 344                    .await;
 345            match result {
 346                ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
 347                ConnectionResult::ConnectionReset => {
 348                    log::info!("Debugger closed the connection");
 349                    return Ok(());
 350                }
 351                ConnectionResult::Result(Ok(Message::Response(res))) => {
 352                    let tx = pending_requests.lock().remove(res.request_seq)?;
 353                    if let Some(tx) = tx {
 354                        if let Err(e) = tx.send(Self::process_response(res)) {
 355                            log::trace!("Did not send response `{:?}` for a cancelled", e);
 356                        }
 357                    } else {
 358                        message_handler(Message::Response(res))
 359                    }
 360                }
 361                ConnectionResult::Result(Ok(message)) => message_handler(message),
 362                ConnectionResult::Result(Err(e)) => break Err(e),
 363            }
 364        };
 365
 366        log::debug!("Handle adapter output dropped");
 367
 368        result
 369    }
 370
 371    fn process_response(response: Response) -> Result<Response> {
 372        if response.success {
 373            Ok(response)
 374        } else {
 375            if let Some(error_message) = response
 376                .body
 377                .clone()
 378                .and_then(|body| serde_json::from_value::<ErrorResponse>(body).ok())
 379                .and_then(|response| response.error.map(|msg| msg.format))
 380                .or_else(|| response.message.clone())
 381            {
 382                anyhow::bail!(error_message);
 383            };
 384
 385            anyhow::bail!(
 386                "Received error response from adapter. Response: {:?}",
 387                response
 388            );
 389        }
 390    }
 391
 392    async fn receive_server_message<Stdout>(
 393        reader: &mut BufReader<Stdout>,
 394        buffer: &mut String,
 395        log_handlers: Option<&LogHandlers>,
 396    ) -> ConnectionResult<Message>
 397    where
 398        Stdout: AsyncRead + Unpin + Send + 'static,
 399    {
 400        let mut content_length = None;
 401        loop {
 402            buffer.truncate(0);
 403            match reader.read_line(buffer).await {
 404                Ok(0) => return ConnectionResult::ConnectionReset,
 405                Ok(_) => {}
 406                Err(e) => return ConnectionResult::Result(Err(e.into())),
 407            };
 408
 409            if buffer == "\r\n" {
 410                break;
 411            }
 412
 413            if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
 414                match value.parse().context("invalid content length") {
 415                    Ok(length) => content_length = Some(length),
 416                    Err(e) => return ConnectionResult::Result(Err(e)),
 417                }
 418            }
 419        }
 420
 421        let content_length = match content_length.context("missing content length") {
 422            Ok(length) => length,
 423            Err(e) => return ConnectionResult::Result(Err(e)),
 424        };
 425
 426        let mut content = vec![0; content_length];
 427        if let Err(e) = reader
 428            .read_exact(&mut content)
 429            .await
 430            .with_context(|| "reading after a loop")
 431        {
 432            return ConnectionResult::Result(Err(e));
 433        }
 434
 435        let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
 436            Ok(str) => str,
 437            Err(e) => return ConnectionResult::Result(Err(e)),
 438        };
 439
 440        let message =
 441            serde_json::from_str::<Message>(message_str).context("deserializing server message");
 442
 443        if let Some(log_handlers) = log_handlers {
 444            let command = match &message {
 445                Ok(Message::Request(request)) => Some(request.command.as_str()),
 446                Ok(Message::Response(response)) => Some(response.command.as_str()),
 447                _ => None,
 448            };
 449
 450            for (kind, log_handler) in log_handlers.lock().iter_mut() {
 451                if matches!(kind, LogKind::Rpc) {
 452                    log_handler(IoKind::StdOut, command, message_str);
 453                }
 454            }
 455        }
 456
 457        ConnectionResult::Result(message)
 458    }
 459
 460    pub fn has_adapter_logs(&self) -> bool {
 461        self.transport.lock().has_adapter_logs()
 462    }
 463
 464    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
 465    where
 466        F: 'static + Send + FnMut(IoKind, Option<&Command>, &IoMessage),
 467    {
 468        let mut log_handlers = self.log_handlers.lock();
 469        log_handlers.push((kind, Box::new(f)));
 470    }
 471}
 472
 473pub struct TcpTransport {
 474    executor: BackgroundExecutor,
 475    pub port: u16,
 476    pub host: Ipv4Addr,
 477    pub timeout: u64,
 478    process: Arc<Mutex<Option<Child>>>,
 479    _stderr_task: Option<Task<()>>,
 480    _stdout_task: Option<Task<()>>,
 481}
 482
 483impl TcpTransport {
 484    /// Get an open port to use with the tcp client when not supplied by debug config
 485    pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
 486        if let Some(port) = host.port {
 487            Ok(port)
 488        } else {
 489            Self::unused_port(host.host()).await
 490        }
 491    }
 492
 493    pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
 494        Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
 495            .await?
 496            .local_addr()?
 497            .port())
 498    }
 499
 500    async fn start(
 501        binary: &DebugAdapterBinary,
 502        log_handlers: LogHandlers,
 503        cx: &mut AsyncApp,
 504    ) -> Result<Self> {
 505        let connection_args = binary
 506            .connection
 507            .as_ref()
 508            .context("No connection arguments provided")?;
 509
 510        let host = connection_args.host;
 511        let port = connection_args.port;
 512
 513        let mut process = None;
 514        let mut stdout_task = None;
 515        let mut stderr_task = None;
 516
 517        if let Some(command) = &binary.command {
 518            let mut command = util::command::new_std_command(&command);
 519
 520            if let Some(cwd) = &binary.cwd {
 521                command.current_dir(cwd);
 522            }
 523
 524            command.args(&binary.arguments);
 525            command.envs(&binary.envs);
 526
 527            let mut p = Child::spawn(command, Stdio::null())
 528                .with_context(|| "failed to start debug adapter.")?;
 529
 530            stdout_task = p.stdout.take().map(|stdout| {
 531                cx.background_executor()
 532                    .spawn(TransportDelegate::handle_adapter_log(
 533                        stdout,
 534                        IoKind::StdOut,
 535                        log_handlers.clone(),
 536                    ))
 537            });
 538            stderr_task = p.stderr.take().map(|stderr| {
 539                cx.background_executor()
 540                    .spawn(TransportDelegate::handle_adapter_log(
 541                        stderr,
 542                        IoKind::StdErr,
 543                        log_handlers,
 544                    ))
 545            });
 546            process = Some(p);
 547        };
 548
 549        let timeout = connection_args.timeout.unwrap_or_else(|| {
 550            cx.update(|cx| DebuggerSettings::get_global(cx).timeout)
 551                .unwrap_or(20000u64)
 552        });
 553
 554        log::info!(
 555            "Debug adapter has connected to TCP server {}:{}",
 556            host,
 557            port
 558        );
 559
 560        let this = Self {
 561            executor: cx.background_executor().clone(),
 562            port,
 563            host,
 564            process: Arc::new(Mutex::new(process)),
 565            timeout,
 566            _stdout_task: stdout_task,
 567            _stderr_task: stderr_task,
 568        };
 569
 570        Ok(this)
 571    }
 572}
 573
 574impl Transport for TcpTransport {
 575    fn has_adapter_logs(&self) -> bool {
 576        true
 577    }
 578
 579    fn kill(&mut self) {
 580        if let Some(process) = &mut *self.process.lock() {
 581            process.kill();
 582        }
 583    }
 584
 585    fn tcp_arguments(&self) -> Option<TcpArguments> {
 586        Some(TcpArguments {
 587            host: self.host,
 588            port: self.port,
 589            timeout: Some(self.timeout),
 590        })
 591    }
 592
 593    fn connect(
 594        &mut self,
 595    ) -> Task<
 596        Result<(
 597            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 598            Box<dyn AsyncRead + Unpin + Send + 'static>,
 599        )>,
 600    > {
 601        let executor = self.executor.clone();
 602        let timeout = self.timeout;
 603        let address = SocketAddrV4::new(self.host, self.port);
 604        let process = self.process.clone();
 605        executor.clone().spawn(async move {
 606            select! {
 607                _ = executor.timer(Duration::from_millis(timeout)).fuse() => {
 608                    anyhow::bail!("Connection to TCP DAP timeout {address}");
 609                },
 610                result = executor.clone().spawn(async move {
 611                    loop {
 612                        match TcpStream::connect(address).await {
 613                            Ok(stream) => {
 614                                let (read, write) = stream.split();
 615                                return Ok((Box::new(write) as _, Box::new(read) as _))
 616                            },
 617                            Err(_) => {
 618                                let has_process = process.lock().is_some();
 619                                if has_process {
 620                                    let status = process.lock().as_mut().unwrap().try_status();
 621                                    if let Ok(Some(_)) = status {
 622                                        let process = process.lock().take().unwrap().into_inner();
 623                                        let output = process.output().await?;
 624                                        let output = if output.stderr.is_empty() {
 625                                            String::from_utf8_lossy(&output.stdout).to_string()
 626                                        } else {
 627                                            String::from_utf8_lossy(&output.stderr).to_string()
 628                                        };
 629                                        anyhow::bail!("{output}\nerror: process exited before debugger attached.");
 630                                    }
 631                                }
 632
 633                                executor.timer(Duration::from_millis(100)).await;
 634                            }
 635                        }
 636                    }
 637                }).fuse() => result
 638            }
 639        })
 640    }
 641}
 642
 643impl Drop for TcpTransport {
 644    fn drop(&mut self) {
 645        if let Some(mut p) = self.process.lock().take() {
 646            p.kill()
 647        }
 648    }
 649}
 650
 651pub struct StdioTransport {
 652    process: Mutex<Option<Child>>,
 653    _stderr_task: Option<Task<()>>,
 654}
 655
 656impl StdioTransport {
 657    // #[allow(dead_code, reason = "This is used in non test builds of Zed")]
 658    async fn start(
 659        binary: &DebugAdapterBinary,
 660        log_handlers: LogHandlers,
 661        cx: &mut AsyncApp,
 662    ) -> Result<Self> {
 663        let Some(binary_command) = &binary.command else {
 664            bail!(
 665                "When using the `stdio` transport, the path to a debug adapter binary must be set by Zed."
 666            );
 667        };
 668        let mut command = util::command::new_std_command(&binary_command);
 669
 670        if let Some(cwd) = &binary.cwd {
 671            command.current_dir(cwd);
 672        }
 673
 674        command.args(&binary.arguments);
 675        command.envs(&binary.envs);
 676
 677        let mut process = Child::spawn(command, Stdio::piped())?;
 678
 679        let err_task = process.stderr.take().map(|stderr| {
 680            cx.background_spawn(TransportDelegate::handle_adapter_log(
 681                stderr,
 682                IoKind::StdErr,
 683                log_handlers,
 684            ))
 685        });
 686
 687        let process = Mutex::new(Some(process));
 688
 689        Ok(Self {
 690            process,
 691            _stderr_task: err_task,
 692        })
 693    }
 694}
 695
 696impl Transport for StdioTransport {
 697    fn has_adapter_logs(&self) -> bool {
 698        false
 699    }
 700
 701    fn kill(&mut self) {
 702        if let Some(process) = &mut *self.process.lock() {
 703            process.kill();
 704        }
 705    }
 706
 707    fn connect(
 708        &mut self,
 709    ) -> Task<
 710        Result<(
 711            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 712            Box<dyn AsyncRead + Unpin + Send + 'static>,
 713        )>,
 714    > {
 715        let result = util::maybe!({
 716            let mut guard = self.process.lock();
 717            let process = guard.as_mut().context("oops")?;
 718            Ok((
 719                Box::new(process.stdin.take().context("Cannot reconnect")?) as _,
 720                Box::new(process.stdout.take().context("Cannot reconnect")?) as _,
 721            ))
 722        });
 723        Task::ready(result)
 724    }
 725
 726    fn tcp_arguments(&self) -> Option<TcpArguments> {
 727        None
 728    }
 729}
 730
 731impl Drop for StdioTransport {
 732    fn drop(&mut self) {
 733        if let Some(process) = &mut *self.process.lock() {
 734            process.kill();
 735        }
 736    }
 737}
 738
 739#[cfg(any(test, feature = "test-support"))]
 740type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;
 741
 742#[cfg(any(test, feature = "test-support"))]
 743type ResponseHandler = Box<dyn Send + Fn(Response)>;
 744
 745#[cfg(any(test, feature = "test-support"))]
 746pub struct FakeTransport {
 747    // for sending fake response back from adapter side
 748    request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 749    // for reverse request responses
 750    response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 751    message_handler: Option<Task<Result<()>>>,
 752    kind: FakeTransportKind,
 753}
 754
 755#[cfg(any(test, feature = "test-support"))]
 756pub enum FakeTransportKind {
 757    Stdio {
 758        stdin_writer: Option<PipeWriter>,
 759        stdout_reader: Option<PipeReader>,
 760    },
 761    Tcp {
 762        connection: TcpArguments,
 763        executor: BackgroundExecutor,
 764    },
 765}
 766
 767#[cfg(any(test, feature = "test-support"))]
 768impl FakeTransport {
 769    pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
 770    where
 771        F: 'static
 772            + Send
 773            + FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
 774    {
 775        self.request_handlers.lock().insert(
 776            R::COMMAND,
 777            Box::new(move |seq, args| {
 778                let result = handler(seq, serde_json::from_value(args).unwrap());
 779                let RequestHandling::Respond(response) = result else {
 780                    return RequestHandling::Exit;
 781                };
 782                let response = match response {
 783                    Ok(response) => Response {
 784                        seq: seq + 1,
 785                        request_seq: seq,
 786                        success: true,
 787                        command: R::COMMAND.into(),
 788                        body: Some(serde_json::to_value(response).unwrap()),
 789                        message: None,
 790                    },
 791                    Err(response) => Response {
 792                        seq: seq + 1,
 793                        request_seq: seq,
 794                        success: false,
 795                        command: R::COMMAND.into(),
 796                        body: Some(serde_json::to_value(response).unwrap()),
 797                        message: None,
 798                    },
 799                };
 800                RequestHandling::Respond(response)
 801            }),
 802        );
 803    }
 804
 805    pub fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
 806    where
 807        F: 'static + Send + Fn(Response),
 808    {
 809        self.response_handlers
 810            .lock()
 811            .insert(R::COMMAND, Box::new(handler));
 812    }
 813
 814    async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
 815        Ok(Self {
 816            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 817            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 818            message_handler: None,
 819            kind: FakeTransportKind::Tcp {
 820                connection,
 821                executor: cx.background_executor().clone(),
 822            },
 823        })
 824    }
 825
 826    async fn handle_messages(
 827        request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 828        response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 829        stdin_reader: PipeReader,
 830        stdout_writer: PipeWriter,
 831    ) -> Result<()> {
 832        use dap_types::requests::{Request, RunInTerminal, StartDebugging};
 833        use serde_json::json;
 834
 835        let mut reader = BufReader::new(stdin_reader);
 836        let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
 837        let mut buffer = String::new();
 838
 839        loop {
 840            match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
 841                ConnectionResult::Timeout => {
 842                    anyhow::bail!("Timed out when connecting to debugger");
 843                }
 844                ConnectionResult::ConnectionReset => {
 845                    log::info!("Debugger closed the connection");
 846                    break Ok(());
 847                }
 848                ConnectionResult::Result(Err(e)) => break Err(e),
 849                ConnectionResult::Result(Ok(message)) => {
 850                    match message {
 851                        Message::Request(request) => {
 852                            // redirect reverse requests to stdout writer/reader
 853                            if request.command == RunInTerminal::COMMAND
 854                                || request.command == StartDebugging::COMMAND
 855                            {
 856                                let message =
 857                                    serde_json::to_string(&Message::Request(request)).unwrap();
 858
 859                                let mut writer = stdout_writer.lock().await;
 860                                writer
 861                                    .write_all(
 862                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 863                                    )
 864                                    .await
 865                                    .unwrap();
 866                                writer.flush().await.unwrap();
 867                            } else {
 868                                let response = if let Some(handle) =
 869                                    request_handlers.lock().get_mut(request.command.as_str())
 870                                {
 871                                    handle(request.seq, request.arguments.unwrap_or(json!({})))
 872                                } else {
 873                                    panic!("No request handler for {}", request.command);
 874                                };
 875                                let response = match response {
 876                                    RequestHandling::Respond(response) => response,
 877                                    RequestHandling::Exit => {
 878                                        break Err(anyhow!("exit in response to request"));
 879                                    }
 880                                };
 881                                let success = response.success;
 882                                let message =
 883                                    serde_json::to_string(&Message::Response(response)).unwrap();
 884
 885                                let mut writer = stdout_writer.lock().await;
 886                                writer
 887                                    .write_all(
 888                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 889                                    )
 890                                    .await
 891                                    .unwrap();
 892
 893                                if request.command == dap_types::requests::Initialize::COMMAND
 894                                    && success
 895                                {
 896                                    let message = serde_json::to_string(&Message::Event(Box::new(
 897                                        dap_types::messages::Events::Initialized(Some(
 898                                            Default::default(),
 899                                        )),
 900                                    )))
 901                                    .unwrap();
 902                                    writer
 903                                        .write_all(
 904                                            TransportDelegate::build_rpc_message(message)
 905                                                .as_bytes(),
 906                                        )
 907                                        .await
 908                                        .unwrap();
 909                                }
 910
 911                                writer.flush().await.unwrap();
 912                            }
 913                        }
 914                        Message::Event(event) => {
 915                            let message = serde_json::to_string(&Message::Event(event)).unwrap();
 916
 917                            let mut writer = stdout_writer.lock().await;
 918                            writer
 919                                .write_all(TransportDelegate::build_rpc_message(message).as_bytes())
 920                                .await
 921                                .unwrap();
 922                            writer.flush().await.unwrap();
 923                        }
 924                        Message::Response(response) => {
 925                            if let Some(handle) =
 926                                response_handlers.lock().get(response.command.as_str())
 927                            {
 928                                handle(response);
 929                            } else {
 930                                log::error!("No response handler for {}", response.command);
 931                            }
 932                        }
 933                    }
 934                }
 935            }
 936        }
 937    }
 938
 939    async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
 940        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 941        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 942        let kind = FakeTransportKind::Stdio {
 943            stdin_writer: Some(stdin_writer),
 944            stdout_reader: Some(stdout_reader),
 945        };
 946
 947        let mut this = Self {
 948            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 949            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 950            message_handler: None,
 951            kind,
 952        };
 953
 954        let request_handlers = this.request_handlers.clone();
 955        let response_handlers = this.response_handlers.clone();
 956
 957        this.message_handler = Some(cx.background_spawn(Self::handle_messages(
 958            request_handlers,
 959            response_handlers,
 960            stdin_reader,
 961            stdout_writer,
 962        )));
 963
 964        Ok(this)
 965    }
 966}
 967
 968#[cfg(any(test, feature = "test-support"))]
 969impl Transport for FakeTransport {
 970    fn tcp_arguments(&self) -> Option<TcpArguments> {
 971        match &self.kind {
 972            FakeTransportKind::Stdio { .. } => None,
 973            FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
 974        }
 975    }
 976
 977    fn connect(
 978        &mut self,
 979    ) -> Task<
 980        Result<(
 981            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 982            Box<dyn AsyncRead + Unpin + Send + 'static>,
 983        )>,
 984    > {
 985        let result = match &mut self.kind {
 986            FakeTransportKind::Stdio {
 987                stdin_writer,
 988                stdout_reader,
 989            } => util::maybe!({
 990                Ok((
 991                    Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
 992                    Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
 993                ))
 994            }),
 995            FakeTransportKind::Tcp { executor, .. } => {
 996                let (stdin_writer, stdin_reader) = async_pipe::pipe();
 997                let (stdout_writer, stdout_reader) = async_pipe::pipe();
 998
 999                let request_handlers = self.request_handlers.clone();
1000                let response_handlers = self.response_handlers.clone();
1001
1002                self.message_handler = Some(executor.spawn(Self::handle_messages(
1003                    request_handlers,
1004                    response_handlers,
1005                    stdin_reader,
1006                    stdout_writer,
1007                )));
1008
1009                Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
1010            }
1011        };
1012        Task::ready(result)
1013    }
1014
1015    fn has_adapter_logs(&self) -> bool {
1016        false
1017    }
1018
1019    fn kill(&mut self) {
1020        self.message_handler.take();
1021    }
1022
1023    #[cfg(any(test, feature = "test-support"))]
1024    fn as_fake(&self) -> &FakeTransport {
1025        self
1026    }
1027}
1028
/// Thin owner of a spawned [`smol::process::Child`].
///
/// Dereferences to the inner handle. Spawning and killing have platform-specific
/// implementations on the inherent `impl Child`.
struct Child {
    /// The underlying smol process handle.
    process: smol::process::Child,
}
1032
1033impl std::ops::Deref for Child {
1034    type Target = smol::process::Child;
1035
1036    fn deref(&self) -> &Self::Target {
1037        &self.process
1038    }
1039}
1040
1041impl std::ops::DerefMut for Child {
1042    fn deref_mut(&mut self) -> &mut Self::Target {
1043        &mut self.process
1044    }
1045}
1046
1047impl Child {
1048    fn into_inner(self) -> smol::process::Child {
1049        self.process
1050    }
1051
1052    #[cfg(not(windows))]
1053    fn spawn(mut command: std::process::Command, stdin: Stdio) -> Result<Self> {
1054        util::set_pre_exec_to_start_new_session(&mut command);
1055        let mut command = smol::process::Command::from(command);
1056        let process = command
1057            .stdin(stdin)
1058            .stdout(Stdio::piped())
1059            .stderr(Stdio::piped())
1060            .spawn()
1061            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1062        Ok(Self { process })
1063    }
1064
1065    #[cfg(windows)]
1066    fn spawn(command: std::process::Command, stdin: Stdio) -> Result<Self> {
1067        // TODO(windows): create a job object and add the child process handle to it,
1068        // see https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects
1069        let mut command = smol::process::Command::from(command);
1070        let process = command
1071            .stdin(stdin)
1072            .stdout(Stdio::piped())
1073            .stderr(Stdio::piped())
1074            .spawn()
1075            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1076        Ok(Self { process })
1077    }
1078
1079    #[cfg(not(windows))]
1080    fn kill(&mut self) {
1081        let pid = self.process.id();
1082        unsafe {
1083            libc::killpg(pid as i32, libc::SIGKILL);
1084        }
1085    }
1086
1087    #[cfg(windows)]
1088    fn kill(&mut self) {
1089        // TODO(windows): terminate the job object in kill
1090        let _ = self.process.kill();
1091    }
1092}