transport.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2#[cfg(any(test, feature = "test-support"))]
   3use async_pipe::{PipeReader, PipeWriter};
   4use dap_types::{
   5    ErrorResponse,
   6    messages::{Message, Response},
   7};
   8use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
   9use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
  10use parking_lot::Mutex;
  11use proto::ErrorExt;
  12use settings::Settings as _;
  13use smallvec::SmallVec;
  14use smol::{
  15    channel::{Receiver, Sender, unbounded},
  16    io::{AsyncBufReadExt as _, AsyncWriteExt, BufReader},
  17    net::{TcpListener, TcpStream},
  18};
  19use std::{
  20    collections::HashMap,
  21    net::{Ipv4Addr, SocketAddrV4},
  22    process::Stdio,
  23    sync::Arc,
  24    time::Duration,
  25};
  26use task::TcpArgumentsTemplate;
  27use util::ConnectionResult;
  28
  29use crate::{
  30    adapters::{DebugAdapterBinary, TcpArguments},
  31    client::DapMessageHandler,
  32    debugger_settings::DebuggerSettings,
  33};
  34
  35pub(crate) type IoMessage = str;
  36pub(crate) type Command = str;
  37pub type IoHandler = Box<dyn Send + FnMut(IoKind, Option<&Command>, &IoMessage)>;
  38
  39#[derive(PartialEq, Eq, Clone, Copy)]
  40pub enum LogKind {
  41    Adapter,
  42    Rpc,
  43}
  44
  45#[derive(Clone, Copy)]
  46pub enum IoKind {
  47    StdIn,
  48    StdOut,
  49    StdErr,
  50}
  51
  52#[cfg(any(test, feature = "test-support"))]
  53pub enum RequestHandling<T> {
  54    Respond(T),
  55    Exit,
  56}
  57
  58type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
  59
  60pub trait Transport: Send + Sync {
  61    fn has_adapter_logs(&self) -> bool;
  62    fn tcp_arguments(&self) -> Option<TcpArguments>;
  63    fn connect(
  64        &mut self,
  65    ) -> Task<
  66        Result<(
  67            Box<dyn AsyncWrite + Unpin + Send + 'static>,
  68            Box<dyn AsyncRead + Unpin + Send + 'static>,
  69        )>,
  70    >;
  71    fn kill(&mut self);
  72    #[cfg(any(test, feature = "test-support"))]
  73    fn as_fake(&self) -> &FakeTransport {
  74        unreachable!()
  75    }
  76}
  77
  78async fn start(
  79    binary: &DebugAdapterBinary,
  80    log_handlers: LogHandlers,
  81    cx: &mut AsyncApp,
  82) -> Result<Box<dyn Transport>> {
  83    #[cfg(any(test, feature = "test-support"))]
  84    if cfg!(any(test, feature = "test-support")) {
  85        if let Some(connection) = binary.connection.clone() {
  86            return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
  87        } else {
  88            return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
  89        }
  90    }
  91
  92    if binary.connection.is_some() {
  93        Ok(Box::new(
  94            TcpTransport::start(binary, log_handlers, cx).await?,
  95        ))
  96    } else {
  97        Ok(Box::new(
  98            StdioTransport::start(binary, log_handlers, cx).await?,
  99        ))
 100    }
 101}
 102
 103pub(crate) struct PendingRequests {
 104    inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
 105}
 106
 107impl PendingRequests {
 108    fn new() -> Self {
 109        Self {
 110            inner: Some(HashMap::default()),
 111        }
 112    }
 113
 114    fn flush(&mut self, e: anyhow::Error) {
 115        let Some(inner) = self.inner.as_mut() else {
 116            return;
 117        };
 118        for (_, sender) in inner.drain() {
 119            sender.send(Err(e.cloned())).ok();
 120        }
 121    }
 122
 123    pub(crate) fn insert(
 124        &mut self,
 125        sequence_id: u64,
 126        callback_tx: oneshot::Sender<Result<Response>>,
 127    ) -> anyhow::Result<()> {
 128        let Some(inner) = self.inner.as_mut() else {
 129            bail!("client is closed")
 130        };
 131        inner.insert(sequence_id, callback_tx);
 132        Ok(())
 133    }
 134
 135    pub(crate) fn remove(
 136        &mut self,
 137        sequence_id: u64,
 138    ) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
 139        let Some(inner) = self.inner.as_mut() else {
 140            bail!("client is closed");
 141        };
 142        Ok(inner.remove(&sequence_id))
 143    }
 144
 145    pub(crate) fn shutdown(&mut self) {
 146        self.flush(anyhow!("transport shutdown"));
 147        self.inner = None;
 148    }
 149}
 150
 151pub(crate) struct TransportDelegate {
 152    log_handlers: LogHandlers,
 153    pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
 154    pub(crate) transport: Mutex<Box<dyn Transport>>,
 155    pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
 156    tasks: Mutex<Vec<Task<()>>>,
 157}
 158
 159impl TransportDelegate {
 160    pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
 161        let log_handlers: LogHandlers = Default::default();
 162        let transport = start(binary, log_handlers.clone(), cx).await?;
 163        Ok(Self {
 164            transport: Mutex::new(transport),
 165            log_handlers,
 166            server_tx: Default::default(),
 167            pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
 168            tasks: Default::default(),
 169        })
 170    }
 171
 172    pub async fn connect(
 173        &self,
 174        message_handler: DapMessageHandler,
 175        cx: &mut AsyncApp,
 176    ) -> Result<()> {
 177        let (server_tx, client_rx) = unbounded::<Message>();
 178        self.tasks.lock().clear();
 179
 180        let log_dap_communications =
 181            cx.update(|cx| DebuggerSettings::get_global(cx).log_dap_communications)
 182                .with_context(|| "Failed to get Debugger Setting log dap communications error in transport::start_handlers. Defaulting to false")
 183                .unwrap_or(false);
 184
 185        let connect = self.transport.lock().connect();
 186        let (input, output) = connect.await?;
 187
 188        let log_handler = if log_dap_communications {
 189            Some(self.log_handlers.clone())
 190        } else {
 191            None
 192        };
 193
 194        let pending_requests = self.pending_requests.clone();
 195        let output_log_handler = log_handler.clone();
 196        {
 197            let mut tasks = self.tasks.lock();
 198            tasks.push(cx.background_spawn(async move {
 199                match Self::recv_from_server(
 200                    output,
 201                    message_handler,
 202                    pending_requests.clone(),
 203                    output_log_handler,
 204                )
 205                .await
 206                {
 207                    Ok(()) => {
 208                        pending_requests
 209                            .lock()
 210                            .flush(anyhow!("debugger shutdown unexpectedly"));
 211                    }
 212                    Err(e) => {
 213                        pending_requests.lock().flush(e);
 214                    }
 215                }
 216            }));
 217
 218            tasks.push(cx.background_spawn(async move {
 219                match Self::send_to_server(input, client_rx, log_handler).await {
 220                    Ok(()) => {}
 221                    Err(e) => log::error!("Error handling debugger input: {e}"),
 222                }
 223            }));
 224        }
 225
 226        {
 227            let mut lock = self.server_tx.lock().await;
 228            *lock = Some(server_tx.clone());
 229        }
 230
 231        Ok(())
 232    }
 233
 234    pub(crate) fn tcp_arguments(&self) -> Option<TcpArguments> {
 235        self.transport.lock().tcp_arguments()
 236    }
 237
 238    pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
 239        if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
 240            server_tx.send(message).await.context("sending message")
 241        } else {
 242            anyhow::bail!("Server tx already dropped")
 243        }
 244    }
 245
 246    async fn handle_adapter_log(
 247        stdout: impl AsyncRead + Unpin + Send + 'static,
 248        iokind: IoKind,
 249        log_handlers: LogHandlers,
 250    ) {
 251        let mut reader = BufReader::new(stdout);
 252        let mut line = String::new();
 253
 254        loop {
 255            line.truncate(0);
 256
 257            match reader.read_line(&mut line).await {
 258                Ok(0) => break,
 259                Ok(_) => {}
 260                Err(e) => {
 261                    log::debug!("handle_adapter_log: {}", e);
 262                    break;
 263                }
 264            }
 265            log::debug!("stderr: {line}");
 266
 267            for (kind, handler) in log_handlers.lock().iter_mut() {
 268                if matches!(kind, LogKind::Adapter) {
 269                    handler(iokind, None, line.as_str());
 270                }
 271            }
 272        }
 273    }
 274
 275    fn build_rpc_message(message: String) -> String {
 276        format!("Content-Length: {}\r\n\r\n{}", message.len(), message)
 277    }
 278
 279    async fn send_to_server<Stdin>(
 280        mut server_stdin: Stdin,
 281        client_rx: Receiver<Message>,
 282        log_handlers: Option<LogHandlers>,
 283    ) -> Result<()>
 284    where
 285        Stdin: AsyncWrite + Unpin + Send + 'static,
 286    {
 287        let result = loop {
 288            match client_rx.recv().await {
 289                Ok(message) => {
 290                    let command = match &message {
 291                        Message::Request(request) => Some(request.command.as_str()),
 292                        Message::Response(response) => Some(response.command.as_str()),
 293                        _ => None,
 294                    };
 295
 296                    let message = match serde_json::to_string(&message) {
 297                        Ok(message) => message,
 298                        Err(e) => break Err(e.into()),
 299                    };
 300
 301                    if let Some(log_handlers) = log_handlers.as_ref() {
 302                        for (kind, log_handler) in log_handlers.lock().iter_mut() {
 303                            if matches!(kind, LogKind::Rpc) {
 304                                log_handler(IoKind::StdIn, command, &message);
 305                            }
 306                        }
 307                    }
 308
 309                    if let Err(e) = server_stdin
 310                        .write_all(Self::build_rpc_message(message).as_bytes())
 311                        .await
 312                    {
 313                        break Err(e.into());
 314                    }
 315
 316                    if let Err(e) = server_stdin.flush().await {
 317                        break Err(e.into());
 318                    }
 319                }
 320                Err(error) => break Err(error.into()),
 321            }
 322        };
 323
 324        log::debug!("Handle adapter input dropped");
 325
 326        result
 327    }
 328
 329    async fn recv_from_server<Stdout>(
 330        server_stdout: Stdout,
 331        mut message_handler: DapMessageHandler,
 332        pending_requests: Arc<Mutex<PendingRequests>>,
 333        log_handlers: Option<LogHandlers>,
 334    ) -> Result<()>
 335    where
 336        Stdout: AsyncRead + Unpin + Send + 'static,
 337    {
 338        let mut recv_buffer = String::new();
 339        let mut reader = BufReader::new(server_stdout);
 340
 341        let result = loop {
 342            let result =
 343                Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
 344                    .await;
 345            match result {
 346                ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
 347                ConnectionResult::ConnectionReset => {
 348                    log::info!("Debugger closed the connection");
 349                    return Ok(());
 350                }
 351                ConnectionResult::Result(Ok(Message::Response(res))) => {
 352                    let tx = pending_requests.lock().remove(res.request_seq)?;
 353                    if let Some(tx) = tx {
 354                        if let Err(e) = tx.send(Self::process_response(res)) {
 355                            log::trace!("Did not send response `{:?}` for a cancelled", e);
 356                        }
 357                    } else {
 358                        message_handler(Message::Response(res))
 359                    }
 360                }
 361                ConnectionResult::Result(Ok(message)) => message_handler(message),
 362                ConnectionResult::Result(Err(e)) => break Err(e),
 363            }
 364        };
 365
 366        log::debug!("Handle adapter output dropped");
 367
 368        result
 369    }
 370
 371    fn process_response(response: Response) -> Result<Response> {
 372        if response.success {
 373            Ok(response)
 374        } else {
 375            if let Some(error_message) = response
 376                .body
 377                .clone()
 378                .and_then(|body| serde_json::from_value::<ErrorResponse>(body).ok())
 379                .and_then(|response| response.error.map(|msg| msg.format))
 380                .or_else(|| response.message.clone())
 381            {
 382                anyhow::bail!(error_message);
 383            };
 384
 385            anyhow::bail!(
 386                "Received error response from adapter. Response: {:?}",
 387                response
 388            );
 389        }
 390    }
 391
 392    async fn receive_server_message<Stdout>(
 393        reader: &mut BufReader<Stdout>,
 394        buffer: &mut String,
 395        log_handlers: Option<&LogHandlers>,
 396    ) -> ConnectionResult<Message>
 397    where
 398        Stdout: AsyncRead + Unpin + Send + 'static,
 399    {
 400        let mut content_length = None;
 401        loop {
 402            buffer.truncate(0);
 403            match reader.read_line(buffer).await {
 404                Ok(0) => return ConnectionResult::ConnectionReset,
 405                Ok(_) => {}
 406                Err(e) => return ConnectionResult::Result(Err(e.into())),
 407            };
 408
 409            if buffer == "\r\n" {
 410                break;
 411            }
 412
 413            if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
 414                match value.parse().context("invalid content length") {
 415                    Ok(length) => content_length = Some(length),
 416                    Err(e) => return ConnectionResult::Result(Err(e)),
 417                }
 418            }
 419        }
 420
 421        let content_length = match content_length.context("missing content length") {
 422            Ok(length) => length,
 423            Err(e) => return ConnectionResult::Result(Err(e)),
 424        };
 425
 426        let mut content = vec![0; content_length];
 427        if let Err(e) = reader
 428            .read_exact(&mut content)
 429            .await
 430            .with_context(|| "reading after a loop")
 431        {
 432            return ConnectionResult::Result(Err(e));
 433        }
 434
 435        let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
 436            Ok(str) => str,
 437            Err(e) => return ConnectionResult::Result(Err(e)),
 438        };
 439
 440        let message =
 441            serde_json::from_str::<Message>(message_str).context("deserializing server message");
 442
 443        if let Some(log_handlers) = log_handlers {
 444            let command = match &message {
 445                Ok(Message::Request(request)) => Some(request.command.as_str()),
 446                Ok(Message::Response(response)) => Some(response.command.as_str()),
 447                _ => None,
 448            };
 449
 450            for (kind, log_handler) in log_handlers.lock().iter_mut() {
 451                if matches!(kind, LogKind::Rpc) {
 452                    log_handler(IoKind::StdOut, command, message_str);
 453                }
 454            }
 455        }
 456
 457        ConnectionResult::Result(message)
 458    }
 459
 460    pub fn has_adapter_logs(&self) -> bool {
 461        self.transport.lock().has_adapter_logs()
 462    }
 463
 464    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
 465    where
 466        F: 'static + Send + FnMut(IoKind, Option<&Command>, &IoMessage),
 467    {
 468        let mut log_handlers = self.log_handlers.lock();
 469        log_handlers.push((kind, Box::new(f)));
 470    }
 471}
 472
 473pub struct TcpTransport {
 474    executor: BackgroundExecutor,
 475    pub port: u16,
 476    pub host: Ipv4Addr,
 477    pub timeout: u64,
 478    process: Arc<Mutex<Option<Child>>>,
 479    _stderr_task: Option<Task<()>>,
 480    _stdout_task: Option<Task<()>>,
 481}
 482
 483impl TcpTransport {
 484    /// Get an open port to use with the tcp client when not supplied by debug config
 485    pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
 486        if let Some(port) = host.port {
 487            Ok(port)
 488        } else {
 489            Self::unused_port(host.host()).await
 490        }
 491    }
 492
 493    pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
 494        Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
 495            .await?
 496            .local_addr()?
 497            .port())
 498    }
 499
 500    async fn start(
 501        binary: &DebugAdapterBinary,
 502        log_handlers: LogHandlers,
 503        cx: &mut AsyncApp,
 504    ) -> Result<Self> {
 505        let connection_args = binary
 506            .connection
 507            .as_ref()
 508            .context("No connection arguments provided")?;
 509
 510        let host = connection_args.host;
 511        let port = connection_args.port;
 512
 513        let mut process = None;
 514        let mut stdout_task = None;
 515        let mut stderr_task = None;
 516
 517        if let Some(command) = &binary.command {
 518            let mut command = util::command::new_std_command(&command);
 519
 520            if let Some(cwd) = &binary.cwd {
 521                command.current_dir(cwd);
 522            }
 523
 524            command.args(&binary.arguments);
 525            command.envs(&binary.envs);
 526
 527            let mut p = Child::spawn(command, Stdio::null())
 528                .with_context(|| "failed to start debug adapter.")?;
 529
 530            stdout_task = p.stdout.take().map(|stdout| {
 531                cx.background_executor()
 532                    .spawn(TransportDelegate::handle_adapter_log(
 533                        stdout,
 534                        IoKind::StdOut,
 535                        log_handlers.clone(),
 536                    ))
 537            });
 538            stderr_task = p.stderr.take().map(|stderr| {
 539                cx.background_executor()
 540                    .spawn(TransportDelegate::handle_adapter_log(
 541                        stderr,
 542                        IoKind::StdErr,
 543                        log_handlers,
 544                    ))
 545            });
 546            process = Some(p);
 547        };
 548
 549        let timeout = connection_args.timeout.unwrap_or_else(|| {
 550            cx.update(|cx| DebuggerSettings::get_global(cx).timeout)
 551                .unwrap_or(20000u64)
 552        });
 553
 554        log::info!(
 555            "Debug adapter has connected to TCP server {}:{}",
 556            host,
 557            port
 558        );
 559
 560        let this = Self {
 561            executor: cx.background_executor().clone(),
 562            port,
 563            host,
 564            process: Arc::new(Mutex::new(process)),
 565            timeout,
 566            _stdout_task: stdout_task,
 567            _stderr_task: stderr_task,
 568        };
 569
 570        Ok(this)
 571    }
 572}
 573
 574impl Transport for TcpTransport {
 575    fn has_adapter_logs(&self) -> bool {
 576        true
 577    }
 578
 579    fn kill(&mut self) {
 580        if let Some(process) = &mut *self.process.lock() {
 581            process.kill();
 582        }
 583    }
 584
 585    fn tcp_arguments(&self) -> Option<TcpArguments> {
 586        Some(TcpArguments {
 587            host: self.host,
 588            port: self.port,
 589            timeout: Some(self.timeout),
 590        })
 591    }
 592
 593    fn connect(
 594        &mut self,
 595    ) -> Task<
 596        Result<(
 597            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 598            Box<dyn AsyncRead + Unpin + Send + 'static>,
 599        )>,
 600    > {
 601        let executor = self.executor.clone();
 602        let timeout = self.timeout;
 603        let address = SocketAddrV4::new(self.host, self.port);
 604        let process = self.process.clone();
 605        executor.clone().spawn(async move {
 606            select! {
 607                _ = executor.timer(Duration::from_millis(timeout)).fuse() => {
 608                    anyhow::bail!("Connection to TCP DAP timeout {address}");
 609                },
 610                result = executor.clone().spawn(async move {
 611                    loop {
 612                        match TcpStream::connect(address).await {
 613                            Ok(stream) => {
 614                                let (read, write) = stream.split();
 615                                return Ok((Box::new(write) as _, Box::new(read) as _))
 616                            },
 617                            Err(_) => {
 618                                let has_process = process.lock().is_some();
 619                                if has_process {
 620                                    let status = process.lock().as_mut().unwrap().try_status();
 621                                    if let Ok(Some(_)) = status {
 622                                        let process = process.lock().take().unwrap().into_inner();
 623                                        let output = process.output().await?;
 624                                        let output = if output.stderr.is_empty() {
 625                                            String::from_utf8_lossy(&output.stdout).to_string()
 626                                        } else {
 627                                            String::from_utf8_lossy(&output.stderr).to_string()
 628                                        };
 629                                        anyhow::bail!("{output}\nerror: process exited before debugger attached.");
 630                                    }
 631                                }
 632
 633                                executor.timer(Duration::from_millis(100)).await;
 634                            }
 635                        }
 636                    }
 637                }).fuse() => result
 638            }
 639        })
 640    }
 641}
 642
 643impl Drop for TcpTransport {
 644    fn drop(&mut self) {
 645        if let Some(mut p) = self.process.lock().take() {
 646            p.kill()
 647        }
 648    }
 649}
 650
 651pub struct StdioTransport {
 652    process: Mutex<Child>,
 653    _stderr_task: Option<Task<()>>,
 654}
 655
 656impl StdioTransport {
 657    // #[allow(dead_code, reason = "This is used in non test builds of Zed")]
 658    async fn start(
 659        binary: &DebugAdapterBinary,
 660        log_handlers: LogHandlers,
 661        cx: &mut AsyncApp,
 662    ) -> Result<Self> {
 663        let Some(binary_command) = &binary.command else {
 664            bail!(
 665                "When using the `stdio` transport, the path to a debug adapter binary must be set by Zed."
 666            );
 667        };
 668        let mut command = util::command::new_std_command(&binary_command);
 669
 670        if let Some(cwd) = &binary.cwd {
 671            command.current_dir(cwd);
 672        }
 673
 674        command.args(&binary.arguments);
 675        command.envs(&binary.envs);
 676
 677        let mut process = Child::spawn(command, Stdio::piped())?;
 678
 679        let _stderr_task = process.stderr.take().map(|stderr| {
 680            cx.background_spawn(TransportDelegate::handle_adapter_log(
 681                stderr,
 682                IoKind::StdErr,
 683                log_handlers,
 684            ))
 685        });
 686
 687        let process = Mutex::new(process);
 688
 689        Ok(Self {
 690            process,
 691            _stderr_task,
 692        })
 693    }
 694}
 695
 696impl Transport for StdioTransport {
 697    fn has_adapter_logs(&self) -> bool {
 698        true
 699    }
 700
 701    fn kill(&mut self) {
 702        self.process.lock().kill();
 703    }
 704
 705    fn connect(
 706        &mut self,
 707    ) -> Task<
 708        Result<(
 709            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 710            Box<dyn AsyncRead + Unpin + Send + 'static>,
 711        )>,
 712    > {
 713        let result = util::maybe!({
 714            let mut process = self.process.lock();
 715            Ok((
 716                Box::new(process.stdin.take().context("Cannot reconnect")?) as _,
 717                Box::new(process.stdout.take().context("Cannot reconnect")?) as _,
 718            ))
 719        });
 720        Task::ready(result)
 721    }
 722
 723    fn tcp_arguments(&self) -> Option<TcpArguments> {
 724        None
 725    }
 726}
 727
 728impl Drop for StdioTransport {
 729    fn drop(&mut self) {
 730        self.process.lock().kill();
 731    }
 732}
 733
 734#[cfg(any(test, feature = "test-support"))]
 735type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;
 736
 737#[cfg(any(test, feature = "test-support"))]
 738type ResponseHandler = Box<dyn Send + Fn(Response)>;
 739
 740#[cfg(any(test, feature = "test-support"))]
 741pub struct FakeTransport {
 742    // for sending fake response back from adapter side
 743    request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 744    // for reverse request responses
 745    response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 746    message_handler: Option<Task<Result<()>>>,
 747    kind: FakeTransportKind,
 748}
 749
 750#[cfg(any(test, feature = "test-support"))]
 751pub enum FakeTransportKind {
 752    Stdio {
 753        stdin_writer: Option<PipeWriter>,
 754        stdout_reader: Option<PipeReader>,
 755    },
 756    Tcp {
 757        connection: TcpArguments,
 758        executor: BackgroundExecutor,
 759    },
 760}
 761
 762#[cfg(any(test, feature = "test-support"))]
 763impl FakeTransport {
 764    pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
 765    where
 766        F: 'static
 767            + Send
 768            + FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
 769    {
 770        self.request_handlers.lock().insert(
 771            R::COMMAND,
 772            Box::new(move |seq, args| {
 773                let result = handler(seq, serde_json::from_value(args).unwrap());
 774                let RequestHandling::Respond(response) = result else {
 775                    return RequestHandling::Exit;
 776                };
 777                let response = match response {
 778                    Ok(response) => Response {
 779                        seq: seq + 1,
 780                        request_seq: seq,
 781                        success: true,
 782                        command: R::COMMAND.into(),
 783                        body: Some(serde_json::to_value(response).unwrap()),
 784                        message: None,
 785                    },
 786                    Err(response) => Response {
 787                        seq: seq + 1,
 788                        request_seq: seq,
 789                        success: false,
 790                        command: R::COMMAND.into(),
 791                        body: Some(serde_json::to_value(response).unwrap()),
 792                        message: None,
 793                    },
 794                };
 795                RequestHandling::Respond(response)
 796            }),
 797        );
 798    }
 799
 800    pub fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
 801    where
 802        F: 'static + Send + Fn(Response),
 803    {
 804        self.response_handlers
 805            .lock()
 806            .insert(R::COMMAND, Box::new(handler));
 807    }
 808
 809    async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
 810        Ok(Self {
 811            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 812            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 813            message_handler: None,
 814            kind: FakeTransportKind::Tcp {
 815                connection,
 816                executor: cx.background_executor().clone(),
 817            },
 818        })
 819    }
 820
 821    async fn handle_messages(
 822        request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 823        response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 824        stdin_reader: PipeReader,
 825        stdout_writer: PipeWriter,
 826    ) -> Result<()> {
 827        use dap_types::requests::{Request, RunInTerminal, StartDebugging};
 828        use serde_json::json;
 829
 830        let mut reader = BufReader::new(stdin_reader);
 831        let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
 832        let mut buffer = String::new();
 833
 834        loop {
 835            match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
 836                ConnectionResult::Timeout => {
 837                    anyhow::bail!("Timed out when connecting to debugger");
 838                }
 839                ConnectionResult::ConnectionReset => {
 840                    log::info!("Debugger closed the connection");
 841                    break Ok(());
 842                }
 843                ConnectionResult::Result(Err(e)) => break Err(e),
 844                ConnectionResult::Result(Ok(message)) => {
 845                    match message {
 846                        Message::Request(request) => {
 847                            // redirect reverse requests to stdout writer/reader
 848                            if request.command == RunInTerminal::COMMAND
 849                                || request.command == StartDebugging::COMMAND
 850                            {
 851                                let message =
 852                                    serde_json::to_string(&Message::Request(request)).unwrap();
 853
 854                                let mut writer = stdout_writer.lock().await;
 855                                writer
 856                                    .write_all(
 857                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 858                                    )
 859                                    .await
 860                                    .unwrap();
 861                                writer.flush().await.unwrap();
 862                            } else {
 863                                let response = if let Some(handle) =
 864                                    request_handlers.lock().get_mut(request.command.as_str())
 865                                {
 866                                    handle(request.seq, request.arguments.unwrap_or(json!({})))
 867                                } else {
 868                                    panic!("No request handler for {}", request.command);
 869                                };
 870                                let response = match response {
 871                                    RequestHandling::Respond(response) => response,
 872                                    RequestHandling::Exit => {
 873                                        break Err(anyhow!("exit in response to request"));
 874                                    }
 875                                };
 876                                let success = response.success;
 877                                let message =
 878                                    serde_json::to_string(&Message::Response(response)).unwrap();
 879
 880                                let mut writer = stdout_writer.lock().await;
 881                                writer
 882                                    .write_all(
 883                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 884                                    )
 885                                    .await
 886                                    .unwrap();
 887
 888                                if request.command == dap_types::requests::Initialize::COMMAND
 889                                    && success
 890                                {
 891                                    let message = serde_json::to_string(&Message::Event(Box::new(
 892                                        dap_types::messages::Events::Initialized(Some(
 893                                            Default::default(),
 894                                        )),
 895                                    )))
 896                                    .unwrap();
 897                                    writer
 898                                        .write_all(
 899                                            TransportDelegate::build_rpc_message(message)
 900                                                .as_bytes(),
 901                                        )
 902                                        .await
 903                                        .unwrap();
 904                                }
 905
 906                                writer.flush().await.unwrap();
 907                            }
 908                        }
 909                        Message::Event(event) => {
 910                            let message = serde_json::to_string(&Message::Event(event)).unwrap();
 911
 912                            let mut writer = stdout_writer.lock().await;
 913                            writer
 914                                .write_all(TransportDelegate::build_rpc_message(message).as_bytes())
 915                                .await
 916                                .unwrap();
 917                            writer.flush().await.unwrap();
 918                        }
 919                        Message::Response(response) => {
 920                            if let Some(handle) =
 921                                response_handlers.lock().get(response.command.as_str())
 922                            {
 923                                handle(response);
 924                            } else {
 925                                log::error!("No response handler for {}", response.command);
 926                            }
 927                        }
 928                    }
 929                }
 930            }
 931        }
 932    }
 933
 934    async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
 935        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 936        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 937        let kind = FakeTransportKind::Stdio {
 938            stdin_writer: Some(stdin_writer),
 939            stdout_reader: Some(stdout_reader),
 940        };
 941
 942        let mut this = Self {
 943            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 944            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 945            message_handler: None,
 946            kind,
 947        };
 948
 949        let request_handlers = this.request_handlers.clone();
 950        let response_handlers = this.response_handlers.clone();
 951
 952        this.message_handler = Some(cx.background_spawn(Self::handle_messages(
 953            request_handlers,
 954            response_handlers,
 955            stdin_reader,
 956            stdout_writer,
 957        )));
 958
 959        Ok(this)
 960    }
 961}
 962
 963#[cfg(any(test, feature = "test-support"))]
 964impl Transport for FakeTransport {
 965    fn tcp_arguments(&self) -> Option<TcpArguments> {
 966        match &self.kind {
 967            FakeTransportKind::Stdio { .. } => None,
 968            FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
 969        }
 970    }
 971
 972    fn connect(
 973        &mut self,
 974    ) -> Task<
 975        Result<(
 976            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 977            Box<dyn AsyncRead + Unpin + Send + 'static>,
 978        )>,
 979    > {
 980        let result = match &mut self.kind {
 981            FakeTransportKind::Stdio {
 982                stdin_writer,
 983                stdout_reader,
 984            } => util::maybe!({
 985                Ok((
 986                    Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
 987                    Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
 988                ))
 989            }),
 990            FakeTransportKind::Tcp { executor, .. } => {
 991                let (stdin_writer, stdin_reader) = async_pipe::pipe();
 992                let (stdout_writer, stdout_reader) = async_pipe::pipe();
 993
 994                let request_handlers = self.request_handlers.clone();
 995                let response_handlers = self.response_handlers.clone();
 996
 997                self.message_handler = Some(executor.spawn(Self::handle_messages(
 998                    request_handlers,
 999                    response_handlers,
1000                    stdin_reader,
1001                    stdout_writer,
1002                )));
1003
1004                Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
1005            }
1006        };
1007        Task::ready(result)
1008    }
1009
1010    fn has_adapter_logs(&self) -> bool {
1011        false
1012    }
1013
1014    fn kill(&mut self) {
1015        self.message_handler.take();
1016    }
1017
1018    #[cfg(any(test, feature = "test-support"))]
1019    fn as_fake(&self) -> &FakeTransport {
1020        self
1021    }
1022}
1023
1024struct Child {
1025    process: smol::process::Child,
1026}
1027
1028impl std::ops::Deref for Child {
1029    type Target = smol::process::Child;
1030
1031    fn deref(&self) -> &Self::Target {
1032        &self.process
1033    }
1034}
1035
1036impl std::ops::DerefMut for Child {
1037    fn deref_mut(&mut self) -> &mut Self::Target {
1038        &mut self.process
1039    }
1040}
1041
1042impl Child {
1043    fn into_inner(self) -> smol::process::Child {
1044        self.process
1045    }
1046
1047    #[cfg(not(windows))]
1048    fn spawn(mut command: std::process::Command, stdin: Stdio) -> Result<Self> {
1049        util::set_pre_exec_to_start_new_session(&mut command);
1050        let mut command = smol::process::Command::from(command);
1051        let process = command
1052            .stdin(stdin)
1053            .stdout(Stdio::piped())
1054            .stderr(Stdio::piped())
1055            .spawn()
1056            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1057        Ok(Self { process })
1058    }
1059
1060    #[cfg(windows)]
1061    fn spawn(command: std::process::Command, stdin: Stdio) -> Result<Self> {
1062        // TODO(windows): create a job object and add the child process handle to it,
1063        // see https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects
1064        let mut command = smol::process::Command::from(command);
1065        let process = command
1066            .stdin(stdin)
1067            .stdout(Stdio::piped())
1068            .stderr(Stdio::piped())
1069            .spawn()
1070            .with_context(|| format!("failed to spawn command `{command:?}`",))?;
1071        Ok(Self { process })
1072    }
1073
1074    #[cfg(not(windows))]
1075    fn kill(&mut self) {
1076        let pid = self.process.id();
1077        unsafe {
1078            libc::killpg(pid as i32, libc::SIGKILL);
1079        }
1080    }
1081
1082    #[cfg(windows)]
1083    fn kill(&mut self) {
1084        // TODO(windows): terminate the job object in kill
1085        let _ = self.process.kill();
1086    }
1087}