transport.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2#[cfg(any(test, feature = "test-support"))]
   3use async_pipe::{PipeReader, PipeWriter};
   4use dap_types::{
   5    ErrorResponse,
   6    messages::{Message, Response},
   7};
   8use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
   9use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
  10use parking_lot::Mutex;
  11use proto::ErrorExt;
  12use settings::Settings as _;
  13use smallvec::SmallVec;
  14use smol::{
  15    channel::{Receiver, Sender, unbounded},
  16    io::{AsyncBufReadExt as _, AsyncWriteExt, BufReader},
  17    net::{TcpListener, TcpStream},
  18};
  19use std::{
  20    collections::HashMap,
  21    net::{Ipv4Addr, SocketAddrV4},
  22    process::Stdio,
  23    sync::Arc,
  24    time::Duration,
  25};
  26use task::TcpArgumentsTemplate;
  27use util::ConnectionResult;
  28
  29use crate::{
  30    adapters::{DebugAdapterBinary, TcpArguments},
  31    client::DapMessageHandler,
  32    debugger_settings::DebuggerSettings,
  33};
  34
  35pub(crate) type IoMessage = str;
  36pub(crate) type Command = str;
  37pub type IoHandler = Box<dyn Send + FnMut(IoKind, Option<&Command>, &IoMessage)>;
  38
  39#[derive(PartialEq, Eq, Clone, Copy)]
  40pub enum LogKind {
  41    Adapter,
  42    Rpc,
  43}
  44
  45#[derive(Clone, Copy)]
  46pub enum IoKind {
  47    StdIn,
  48    StdOut,
  49    StdErr,
  50}
  51
  52#[cfg(any(test, feature = "test-support"))]
  53pub enum RequestHandling<T> {
  54    Respond(T),
  55    Exit,
  56}
  57
/// Shared, lock-protected list of log handlers, each tagged with the
/// [`LogKind`] it subscribes to. Up to two handlers are stored inline in the
/// `SmallVec` before it spills to the heap.
type LogHandlers = Arc<Mutex<SmallVec<[(LogKind, IoHandler); 2]>>>;
  59
  60pub trait Transport: Send + Sync {
  61    fn has_adapter_logs(&self) -> bool;
  62    fn tcp_arguments(&self) -> Option<TcpArguments>;
  63    fn connect(
  64        &mut self,
  65    ) -> Task<
  66        Result<(
  67            Box<dyn AsyncWrite + Unpin + Send + 'static>,
  68            Box<dyn AsyncRead + Unpin + Send + 'static>,
  69        )>,
  70    >;
  71    fn kill(&mut self);
  72    #[cfg(any(test, feature = "test-support"))]
  73    fn as_fake(&self) -> &FakeTransport {
  74        unreachable!()
  75    }
  76}
  77
  78async fn start(
  79    binary: &DebugAdapterBinary,
  80    log_handlers: LogHandlers,
  81    cx: &mut AsyncApp,
  82) -> Result<Box<dyn Transport>> {
  83    #[cfg(any(test, feature = "test-support"))]
  84    if cfg!(any(test, feature = "test-support")) {
  85        if let Some(connection) = binary.connection.clone() {
  86            return Ok(Box::new(FakeTransport::start_tcp(connection, cx).await?));
  87        } else {
  88            return Ok(Box::new(FakeTransport::start_stdio(cx).await?));
  89        }
  90    }
  91
  92    if binary.connection.is_some() {
  93        Ok(Box::new(
  94            TcpTransport::start(binary, log_handlers, cx).await?,
  95        ))
  96    } else {
  97        Ok(Box::new(
  98            StdioTransport::start(binary, log_handlers, cx).await?,
  99        ))
 100    }
 101}
 102
/// Response callbacks for requests that are still awaiting an answer from the
/// adapter, keyed by the request's sequence number.
pub(crate) struct PendingRequests {
    // `None` after `shutdown`; from then on `insert`/`remove` fail with
    // "client is closed" and `flush` is a no-op.
    inner: Option<HashMap<u64, oneshot::Sender<Result<Response>>>>,
}
 106
 107impl PendingRequests {
 108    fn new() -> Self {
 109        Self {
 110            inner: Some(HashMap::default()),
 111        }
 112    }
 113
 114    fn flush(&mut self, e: anyhow::Error) {
 115        let Some(inner) = self.inner.as_mut() else {
 116            return;
 117        };
 118        for (_, sender) in inner.drain() {
 119            sender.send(Err(e.cloned())).ok();
 120        }
 121    }
 122
 123    pub(crate) fn insert(
 124        &mut self,
 125        sequence_id: u64,
 126        callback_tx: oneshot::Sender<Result<Response>>,
 127    ) -> anyhow::Result<()> {
 128        let Some(inner) = self.inner.as_mut() else {
 129            bail!("client is closed")
 130        };
 131        inner.insert(sequence_id, callback_tx);
 132        Ok(())
 133    }
 134
 135    pub(crate) fn remove(
 136        &mut self,
 137        sequence_id: u64,
 138    ) -> anyhow::Result<Option<oneshot::Sender<Result<Response>>>> {
 139        let Some(inner) = self.inner.as_mut() else {
 140            bail!("client is closed");
 141        };
 142        Ok(inner.remove(&sequence_id))
 143    }
 144
 145    pub(crate) fn shutdown(&mut self) {
 146        self.flush(anyhow!("transport shutdown"));
 147        self.inner = None;
 148    }
 149}
 150
/// Owns a [`Transport`] and the background tasks that move DAP messages
/// between the client and the debug adapter.
pub(crate) struct TransportDelegate {
    // Handlers notified of adapter output and of RPC traffic.
    log_handlers: LogHandlers,
    // Response callbacks for requests still awaiting an answer.
    pub(crate) pending_requests: Arc<Mutex<PendingRequests>>,
    // The concrete transport (TCP, stdio, or a fake in test builds).
    pub(crate) transport: Mutex<Box<dyn Transport>>,
    // Sender feeding the writer task; `None` until `connect` has run.
    pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
    // Reader/writer tasks of the current connection; cleared on reconnect.
    tasks: Mutex<Vec<Task<()>>>,
}
 158
 159impl TransportDelegate {
 160    pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
 161        let log_handlers: LogHandlers = Default::default();
 162        let transport = start(binary, log_handlers.clone(), cx).await?;
 163        Ok(Self {
 164            transport: Mutex::new(transport),
 165            log_handlers,
 166            server_tx: Default::default(),
 167            pending_requests: Arc::new(Mutex::new(PendingRequests::new())),
 168            tasks: Default::default(),
 169        })
 170    }
 171
 172    pub async fn connect(
 173        &self,
 174        message_handler: DapMessageHandler,
 175        cx: &mut AsyncApp,
 176    ) -> Result<()> {
 177        let (server_tx, client_rx) = unbounded::<Message>();
 178        self.tasks.lock().clear();
 179
 180        let log_dap_communications =
 181            cx.update(|cx| DebuggerSettings::get_global(cx).log_dap_communications)
 182                .with_context(|| "Failed to get Debugger Setting log dap communications error in transport::start_handlers. Defaulting to false")
 183                .unwrap_or(false);
 184
 185        let connect = self.transport.lock().connect();
 186        let (input, output) = connect.await?;
 187
 188        let log_handler = if log_dap_communications {
 189            Some(self.log_handlers.clone())
 190        } else {
 191            None
 192        };
 193
 194        let pending_requests = self.pending_requests.clone();
 195        let output_log_handler = log_handler.clone();
 196        {
 197            let mut tasks = self.tasks.lock();
 198            tasks.push(cx.background_spawn(async move {
 199                match Self::recv_from_server(
 200                    output,
 201                    message_handler,
 202                    pending_requests.clone(),
 203                    output_log_handler,
 204                )
 205                .await
 206                {
 207                    Ok(()) => {
 208                        pending_requests
 209                            .lock()
 210                            .flush(anyhow!("debugger shutdown unexpectedly"));
 211                    }
 212                    Err(e) => {
 213                        pending_requests.lock().flush(e);
 214                    }
 215                }
 216            }));
 217
 218            tasks.push(cx.background_spawn(async move {
 219                match Self::send_to_server(input, client_rx, log_handler).await {
 220                    Ok(()) => {}
 221                    Err(e) => log::error!("Error handling debugger input: {e}"),
 222                }
 223            }));
 224        }
 225
 226        {
 227            let mut lock = self.server_tx.lock().await;
 228            *lock = Some(server_tx.clone());
 229        }
 230
 231        Ok(())
 232    }
 233
 234    pub(crate) fn tcp_arguments(&self) -> Option<TcpArguments> {
 235        self.transport.lock().tcp_arguments()
 236    }
 237
 238    pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
 239        if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
 240            server_tx.send(message).await.context("sending message")
 241        } else {
 242            anyhow::bail!("Server tx already dropped")
 243        }
 244    }
 245
 246    async fn handle_adapter_log(
 247        stdout: impl AsyncRead + Unpin + Send + 'static,
 248        iokind: IoKind,
 249        log_handlers: LogHandlers,
 250    ) {
 251        let mut reader = BufReader::new(stdout);
 252        let mut line = String::new();
 253
 254        loop {
 255            line.truncate(0);
 256
 257            match reader.read_line(&mut line).await {
 258                Ok(0) => break,
 259                Ok(_) => {}
 260                Err(e) => {
 261                    log::debug!("handle_adapter_log: {}", e);
 262                    break;
 263                }
 264            }
 265            log::debug!("stderr: {line}");
 266
 267            for (kind, handler) in log_handlers.lock().iter_mut() {
 268                if matches!(kind, LogKind::Adapter) {
 269                    handler(iokind, None, line.as_str());
 270                }
 271            }
 272        }
 273    }
 274
 275    fn build_rpc_message(message: String) -> String {
 276        format!("Content-Length: {}\r\n\r\n{}", message.len(), message)
 277    }
 278
 279    async fn send_to_server<Stdin>(
 280        mut server_stdin: Stdin,
 281        client_rx: Receiver<Message>,
 282        log_handlers: Option<LogHandlers>,
 283    ) -> Result<()>
 284    where
 285        Stdin: AsyncWrite + Unpin + Send + 'static,
 286    {
 287        let result = loop {
 288            match client_rx.recv().await {
 289                Ok(message) => {
 290                    let command = match &message {
 291                        Message::Request(request) => Some(request.command.as_str()),
 292                        Message::Response(response) => Some(response.command.as_str()),
 293                        _ => None,
 294                    };
 295
 296                    let message = match serde_json::to_string(&message) {
 297                        Ok(message) => message,
 298                        Err(e) => break Err(e.into()),
 299                    };
 300
 301                    if let Some(log_handlers) = log_handlers.as_ref() {
 302                        for (kind, log_handler) in log_handlers.lock().iter_mut() {
 303                            if matches!(kind, LogKind::Rpc) {
 304                                log_handler(IoKind::StdIn, command, &message);
 305                            }
 306                        }
 307                    }
 308
 309                    if let Err(e) = server_stdin
 310                        .write_all(Self::build_rpc_message(message).as_bytes())
 311                        .await
 312                    {
 313                        break Err(e.into());
 314                    }
 315
 316                    if let Err(e) = server_stdin.flush().await {
 317                        break Err(e.into());
 318                    }
 319                }
 320                Err(error) => break Err(error.into()),
 321            }
 322        };
 323
 324        log::debug!("Handle adapter input dropped");
 325
 326        result
 327    }
 328
 329    async fn recv_from_server<Stdout>(
 330        server_stdout: Stdout,
 331        mut message_handler: DapMessageHandler,
 332        pending_requests: Arc<Mutex<PendingRequests>>,
 333        log_handlers: Option<LogHandlers>,
 334    ) -> Result<()>
 335    where
 336        Stdout: AsyncRead + Unpin + Send + 'static,
 337    {
 338        let mut recv_buffer = String::new();
 339        let mut reader = BufReader::new(server_stdout);
 340
 341        let result = loop {
 342            let result =
 343                Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
 344                    .await;
 345            match result {
 346                ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
 347                ConnectionResult::ConnectionReset => {
 348                    log::info!("Debugger closed the connection");
 349                    return Ok(());
 350                }
 351                ConnectionResult::Result(Ok(Message::Response(res))) => {
 352                    let tx = pending_requests.lock().remove(res.request_seq)?;
 353                    if let Some(tx) = tx {
 354                        if let Err(e) = tx.send(Self::process_response(res)) {
 355                            log::trace!("Did not send response `{:?}` for a cancelled", e);
 356                        }
 357                    } else {
 358                        message_handler(Message::Response(res))
 359                    }
 360                }
 361                ConnectionResult::Result(Ok(message)) => message_handler(message),
 362                ConnectionResult::Result(Err(e)) => break Err(e),
 363            }
 364        };
 365
 366        log::debug!("Handle adapter output dropped");
 367
 368        result
 369    }
 370
 371    fn process_response(response: Response) -> Result<Response> {
 372        if response.success {
 373            Ok(response)
 374        } else {
 375            if let Some(error_message) = response
 376                .body
 377                .clone()
 378                .and_then(|body| serde_json::from_value::<ErrorResponse>(body).ok())
 379                .and_then(|response| response.error.map(|msg| msg.format))
 380                .or_else(|| response.message.clone())
 381            {
 382                anyhow::bail!(error_message);
 383            };
 384
 385            anyhow::bail!(
 386                "Received error response from adapter. Response: {:?}",
 387                response
 388            );
 389        }
 390    }
 391
 392    async fn receive_server_message<Stdout>(
 393        reader: &mut BufReader<Stdout>,
 394        buffer: &mut String,
 395        log_handlers: Option<&LogHandlers>,
 396    ) -> ConnectionResult<Message>
 397    where
 398        Stdout: AsyncRead + Unpin + Send + 'static,
 399    {
 400        let mut content_length = None;
 401        loop {
 402            buffer.truncate(0);
 403            match reader.read_line(buffer).await {
 404                Ok(0) => return ConnectionResult::ConnectionReset,
 405                Ok(_) => {}
 406                Err(e) => return ConnectionResult::Result(Err(e.into())),
 407            };
 408
 409            if buffer == "\r\n" {
 410                break;
 411            }
 412
 413            if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
 414                match value.parse().context("invalid content length") {
 415                    Ok(length) => content_length = Some(length),
 416                    Err(e) => return ConnectionResult::Result(Err(e)),
 417                }
 418            }
 419        }
 420
 421        let content_length = match content_length.context("missing content length") {
 422            Ok(length) => length,
 423            Err(e) => return ConnectionResult::Result(Err(e)),
 424        };
 425
 426        let mut content = vec![0; content_length];
 427        if let Err(e) = reader
 428            .read_exact(&mut content)
 429            .await
 430            .with_context(|| "reading after a loop")
 431        {
 432            return ConnectionResult::Result(Err(e));
 433        }
 434
 435        let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
 436            Ok(str) => str,
 437            Err(e) => return ConnectionResult::Result(Err(e)),
 438        };
 439
 440        let message =
 441            serde_json::from_str::<Message>(message_str).context("deserializing server message");
 442
 443        if let Some(log_handlers) = log_handlers {
 444            let command = match &message {
 445                Ok(Message::Request(request)) => Some(request.command.as_str()),
 446                Ok(Message::Response(response)) => Some(response.command.as_str()),
 447                _ => None,
 448            };
 449
 450            for (kind, log_handler) in log_handlers.lock().iter_mut() {
 451                if matches!(kind, LogKind::Rpc) {
 452                    log_handler(IoKind::StdOut, command, message_str);
 453                }
 454            }
 455        }
 456
 457        ConnectionResult::Result(message)
 458    }
 459
 460    pub fn has_adapter_logs(&self) -> bool {
 461        self.transport.lock().has_adapter_logs()
 462    }
 463
 464    pub fn add_log_handler<F>(&self, f: F, kind: LogKind)
 465    where
 466        F: 'static + Send + FnMut(IoKind, Option<&Command>, &IoMessage),
 467    {
 468        let mut log_handlers = self.log_handlers.lock();
 469        log_handlers.push((kind, Box::new(f)));
 470    }
 471}
 472
/// Transport that talks to the adapter over a TCP socket, optionally spawning
/// the adapter process first.
pub struct TcpTransport {
    executor: BackgroundExecutor,
    pub port: u16,
    pub host: Ipv4Addr,
    /// Connection timeout in milliseconds.
    pub timeout: u64,
    // Spawned adapter process, if any. It is shared with the connect task so
    // that task can notice when the process exits before a socket is accepted.
    process: Arc<Mutex<Option<Child>>>,
    // Tasks forwarding the spawned process's stderr/stdout to the log handlers.
    _stderr_task: Option<Task<()>>,
    _stdout_task: Option<Task<()>>,
}
 482
 483impl TcpTransport {
 484    /// Get an open port to use with the tcp client when not supplied by debug config
 485    pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
 486        if let Some(port) = host.port {
 487            Ok(port)
 488        } else {
 489            Self::unused_port(host.host()).await
 490        }
 491    }
 492
 493    pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
 494        Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
 495            .await?
 496            .local_addr()?
 497            .port())
 498    }
 499
 500    async fn start(
 501        binary: &DebugAdapterBinary,
 502        log_handlers: LogHandlers,
 503        cx: &mut AsyncApp,
 504    ) -> Result<Self> {
 505        let connection_args = binary
 506            .connection
 507            .as_ref()
 508            .context("No connection arguments provided")?;
 509
 510        let host = connection_args.host;
 511        let port = connection_args.port;
 512
 513        let mut process = None;
 514        let mut stdout_task = None;
 515        let mut stderr_task = None;
 516
 517        if let Some(command) = &binary.command {
 518            let mut command = util::command::new_std_command(&command);
 519
 520            if let Some(cwd) = &binary.cwd {
 521                command.current_dir(cwd);
 522            }
 523
 524            command.args(&binary.arguments);
 525            command.envs(&binary.envs);
 526
 527            let mut p = Child::spawn(command, Stdio::null())
 528                .with_context(|| "failed to start debug adapter.")?;
 529
 530            stdout_task = p.stdout.take().map(|stdout| {
 531                cx.background_executor()
 532                    .spawn(TransportDelegate::handle_adapter_log(
 533                        stdout,
 534                        IoKind::StdOut,
 535                        log_handlers.clone(),
 536                    ))
 537            });
 538            stderr_task = p.stderr.take().map(|stderr| {
 539                cx.background_executor()
 540                    .spawn(TransportDelegate::handle_adapter_log(
 541                        stderr,
 542                        IoKind::StdErr,
 543                        log_handlers,
 544                    ))
 545            });
 546            process = Some(p);
 547        };
 548
 549        let timeout = connection_args.timeout.unwrap_or_else(|| {
 550            cx.update(|cx| DebuggerSettings::get_global(cx).timeout)
 551                .unwrap_or(20000u64)
 552        });
 553
 554        log::info!(
 555            "Debug adapter has connected to TCP server {}:{}",
 556            host,
 557            port
 558        );
 559
 560        let this = Self {
 561            executor: cx.background_executor().clone(),
 562            port,
 563            host,
 564            process: Arc::new(Mutex::new(process)),
 565            timeout,
 566            _stdout_task: stdout_task,
 567            _stderr_task: stderr_task,
 568        };
 569
 570        Ok(this)
 571    }
 572}
 573
 574impl Transport for TcpTransport {
 575    fn has_adapter_logs(&self) -> bool {
 576        true
 577    }
 578
 579    fn kill(&mut self) {
 580        if let Some(process) = &mut *self.process.lock() {
 581            process.kill();
 582        }
 583    }
 584
 585    fn tcp_arguments(&self) -> Option<TcpArguments> {
 586        Some(TcpArguments {
 587            host: self.host,
 588            port: self.port,
 589            timeout: Some(self.timeout),
 590        })
 591    }
 592
 593    fn connect(
 594        &mut self,
 595    ) -> Task<
 596        Result<(
 597            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 598            Box<dyn AsyncRead + Unpin + Send + 'static>,
 599        )>,
 600    > {
 601        let executor = self.executor.clone();
 602        let timeout = self.timeout;
 603        let address = SocketAddrV4::new(self.host, self.port);
 604        let process = self.process.clone();
 605        executor.clone().spawn(async move {
 606            select! {
 607                _ = executor.timer(Duration::from_millis(timeout)).fuse() => {
 608                    anyhow::bail!("Connection to TCP DAP timeout {address}");
 609                },
 610                result = executor.clone().spawn(async move {
 611                    loop {
 612                        match TcpStream::connect(address).await {
 613                            Ok(stream) => {
 614                                let (read, write) = stream.split();
 615                                return Ok((Box::new(write) as _, Box::new(read) as _))
 616                            },
 617                            Err(_) => {
 618                                let has_process = process.lock().is_some();
 619                                if has_process {
 620                                    let status = process.lock().as_mut().unwrap().try_status();
 621                                    if let Ok(Some(_)) = status {
 622                                        let process = process.lock().take().unwrap().into_inner();
 623                                        let output = process.output().await?;
 624                                        let output = if output.stderr.is_empty() {
 625                                            String::from_utf8_lossy(&output.stdout).to_string()
 626                                        } else {
 627                                            String::from_utf8_lossy(&output.stderr).to_string()
 628                                        };
 629                                        anyhow::bail!("{output}\nerror: process exited before debugger attached.");
 630                                    }
 631                                }
 632
 633                                executor.timer(Duration::from_millis(100)).await;
 634                            }
 635                        }
 636                    }
 637                }).fuse() => result
 638            }
 639        })
 640    }
 641}
 642
 643impl Drop for TcpTransport {
 644    fn drop(&mut self) {
 645        if let Some(mut p) = self.process.lock().take() {
 646            p.kill()
 647        }
 648    }
 649}
 650
/// Transport that speaks DAP over the stdin/stdout pipes of a spawned adapter
/// process.
pub struct StdioTransport {
    // The adapter process; `connect` takes its stdin/stdout pipes.
    process: Mutex<Option<Child>>,
    // Task forwarding the adapter's stderr to the log handlers.
    _stderr_task: Option<Task<()>>,
}
 655
 656impl StdioTransport {
 657    // #[allow(dead_code, reason = "This is used in non test builds of Zed")]
 658    async fn start(
 659        binary: &DebugAdapterBinary,
 660        log_handlers: LogHandlers,
 661        cx: &mut AsyncApp,
 662    ) -> Result<Self> {
 663        let Some(binary_command) = &binary.command else {
 664            bail!(
 665                "When using the `stdio` transport, the path to a debug adapter binary must be set by Zed."
 666            );
 667        };
 668        let mut command = util::command::new_std_command(&binary_command);
 669
 670        if let Some(cwd) = &binary.cwd {
 671            command.current_dir(cwd);
 672        }
 673
 674        command.args(&binary.arguments);
 675        command.envs(&binary.envs);
 676
 677        let mut process = Child::spawn(command, Stdio::piped()).with_context(|| {
 678            format!(
 679                "failed to spawn command `{} {}`.",
 680                binary_command,
 681                binary.arguments.join(" ")
 682            )
 683        })?;
 684
 685        let err_task = process.stderr.take().map(|stderr| {
 686            cx.background_spawn(TransportDelegate::handle_adapter_log(
 687                stderr,
 688                IoKind::StdErr,
 689                log_handlers,
 690            ))
 691        });
 692
 693        let process = Mutex::new(Some(process));
 694
 695        Ok(Self {
 696            process,
 697            _stderr_task: err_task,
 698        })
 699    }
 700}
 701
 702impl Transport for StdioTransport {
 703    fn has_adapter_logs(&self) -> bool {
 704        false
 705    }
 706
 707    fn kill(&mut self) {
 708        if let Some(process) = &mut *self.process.lock() {
 709            process.kill();
 710        }
 711    }
 712
 713    fn connect(
 714        &mut self,
 715    ) -> Task<
 716        Result<(
 717            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 718            Box<dyn AsyncRead + Unpin + Send + 'static>,
 719        )>,
 720    > {
 721        let result = util::maybe!({
 722            let mut guard = self.process.lock();
 723            let process = guard.as_mut().context("oops")?;
 724            Ok((
 725                Box::new(process.stdin.take().context("Cannot reconnect")?) as _,
 726                Box::new(process.stdout.take().context("Cannot reconnect")?) as _,
 727            ))
 728        });
 729        Task::ready(result)
 730    }
 731
 732    fn tcp_arguments(&self) -> Option<TcpArguments> {
 733        None
 734    }
 735}
 736
 737impl Drop for StdioTransport {
 738    fn drop(&mut self) {
 739        if let Some(process) = &mut *self.process.lock() {
 740            process.kill();
 741        }
 742    }
 743}
 744
 745#[cfg(any(test, feature = "test-support"))]
 746type RequestHandler = Box<dyn Send + FnMut(u64, serde_json::Value) -> RequestHandling<Response>>;
 747
 748#[cfg(any(test, feature = "test-support"))]
 749type ResponseHandler = Box<dyn Send + Fn(Response)>;
 750
/// In-process stand-in for a debug adapter, used by tests.
#[cfg(any(test, feature = "test-support"))]
pub struct FakeTransport {
    // for sending fake response back from adapter side
    // (keyed by DAP command name)
    request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
    // for reverse request responses (keyed by DAP command name)
    response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
    // NOTE(review): presumably the task running the fake message loop once
    // connected; `start_tcp` leaves it `None`.
    message_handler: Option<Task<Result<()>>>,
    // Whether this fake pretends to be a stdio or a TCP adapter.
    kind: FakeTransportKind,
}
 760
/// Connection flavour emulated by a [`FakeTransport`].
#[cfg(any(test, feature = "test-support"))]
pub enum FakeTransportKind {
    /// Pipe-based fake; the `Option`s presumably become `None` once the pipe
    /// ends are handed out (TODO confirm against `start_stdio`/`connect`).
    Stdio {
        stdin_writer: Option<PipeWriter>,
        stdout_reader: Option<PipeReader>,
    },
    /// TCP-flavoured fake holding the connection arguments it was created with.
    Tcp {
        connection: TcpArguments,
        executor: BackgroundExecutor,
    },
}
 772
 773#[cfg(any(test, feature = "test-support"))]
 774impl FakeTransport {
 775    pub fn on_request<R: dap_types::requests::Request, F>(&self, mut handler: F)
 776    where
 777        F: 'static
 778            + Send
 779            + FnMut(u64, R::Arguments) -> RequestHandling<Result<R::Response, ErrorResponse>>,
 780    {
 781        self.request_handlers.lock().insert(
 782            R::COMMAND,
 783            Box::new(move |seq, args| {
 784                let result = handler(seq, serde_json::from_value(args).unwrap());
 785                let RequestHandling::Respond(response) = result else {
 786                    return RequestHandling::Exit;
 787                };
 788                let response = match response {
 789                    Ok(response) => Response {
 790                        seq: seq + 1,
 791                        request_seq: seq,
 792                        success: true,
 793                        command: R::COMMAND.into(),
 794                        body: Some(serde_json::to_value(response).unwrap()),
 795                        message: None,
 796                    },
 797                    Err(response) => Response {
 798                        seq: seq + 1,
 799                        request_seq: seq,
 800                        success: false,
 801                        command: R::COMMAND.into(),
 802                        body: Some(serde_json::to_value(response).unwrap()),
 803                        message: None,
 804                    },
 805                };
 806                RequestHandling::Respond(response)
 807            }),
 808        );
 809    }
 810
 811    pub fn on_response<R: dap_types::requests::Request, F>(&self, handler: F)
 812    where
 813        F: 'static + Send + Fn(Response),
 814    {
 815        self.response_handlers
 816            .lock()
 817            .insert(R::COMMAND, Box::new(handler));
 818    }
 819
 820    async fn start_tcp(connection: TcpArguments, cx: &mut AsyncApp) -> Result<Self> {
 821        Ok(Self {
 822            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 823            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 824            message_handler: None,
 825            kind: FakeTransportKind::Tcp {
 826                connection,
 827                executor: cx.background_executor().clone(),
 828            },
 829        })
 830    }
 831
 832    async fn handle_messages(
 833        request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
 834        response_handlers: Arc<Mutex<HashMap<&'static str, ResponseHandler>>>,
 835        stdin_reader: PipeReader,
 836        stdout_writer: PipeWriter,
 837    ) -> Result<()> {
 838        use dap_types::requests::{Request, RunInTerminal, StartDebugging};
 839        use serde_json::json;
 840
 841        let mut reader = BufReader::new(stdin_reader);
 842        let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
 843        let mut buffer = String::new();
 844
 845        loop {
 846            match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None).await {
 847                ConnectionResult::Timeout => {
 848                    anyhow::bail!("Timed out when connecting to debugger");
 849                }
 850                ConnectionResult::ConnectionReset => {
 851                    log::info!("Debugger closed the connection");
 852                    break Ok(());
 853                }
 854                ConnectionResult::Result(Err(e)) => break Err(e),
 855                ConnectionResult::Result(Ok(message)) => {
 856                    match message {
 857                        Message::Request(request) => {
 858                            // redirect reverse requests to stdout writer/reader
 859                            if request.command == RunInTerminal::COMMAND
 860                                || request.command == StartDebugging::COMMAND
 861                            {
 862                                let message =
 863                                    serde_json::to_string(&Message::Request(request)).unwrap();
 864
 865                                let mut writer = stdout_writer.lock().await;
 866                                writer
 867                                    .write_all(
 868                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 869                                    )
 870                                    .await
 871                                    .unwrap();
 872                                writer.flush().await.unwrap();
 873                            } else {
 874                                let response = if let Some(handle) =
 875                                    request_handlers.lock().get_mut(request.command.as_str())
 876                                {
 877                                    handle(request.seq, request.arguments.unwrap_or(json!({})))
 878                                } else {
 879                                    panic!("No request handler for {}", request.command);
 880                                };
 881                                let response = match response {
 882                                    RequestHandling::Respond(response) => response,
 883                                    RequestHandling::Exit => {
 884                                        break Err(anyhow!("exit in response to request"));
 885                                    }
 886                                };
 887                                let success = response.success;
 888                                let message =
 889                                    serde_json::to_string(&Message::Response(response)).unwrap();
 890
 891                                let mut writer = stdout_writer.lock().await;
 892                                writer
 893                                    .write_all(
 894                                        TransportDelegate::build_rpc_message(message).as_bytes(),
 895                                    )
 896                                    .await
 897                                    .unwrap();
 898
 899                                if request.command == dap_types::requests::Initialize::COMMAND
 900                                    && success
 901                                {
 902                                    let message = serde_json::to_string(&Message::Event(Box::new(
 903                                        dap_types::messages::Events::Initialized(Some(
 904                                            Default::default(),
 905                                        )),
 906                                    )))
 907                                    .unwrap();
 908                                    writer
 909                                        .write_all(
 910                                            TransportDelegate::build_rpc_message(message)
 911                                                .as_bytes(),
 912                                        )
 913                                        .await
 914                                        .unwrap();
 915                                }
 916
 917                                writer.flush().await.unwrap();
 918                            }
 919                        }
 920                        Message::Event(event) => {
 921                            let message = serde_json::to_string(&Message::Event(event)).unwrap();
 922
 923                            let mut writer = stdout_writer.lock().await;
 924                            writer
 925                                .write_all(TransportDelegate::build_rpc_message(message).as_bytes())
 926                                .await
 927                                .unwrap();
 928                            writer.flush().await.unwrap();
 929                        }
 930                        Message::Response(response) => {
 931                            if let Some(handle) =
 932                                response_handlers.lock().get(response.command.as_str())
 933                            {
 934                                handle(response);
 935                            } else {
 936                                log::error!("No response handler for {}", response.command);
 937                            }
 938                        }
 939                    }
 940                }
 941            }
 942        }
 943    }
 944
 945    async fn start_stdio(cx: &mut AsyncApp) -> Result<Self> {
 946        let (stdin_writer, stdin_reader) = async_pipe::pipe();
 947        let (stdout_writer, stdout_reader) = async_pipe::pipe();
 948        let kind = FakeTransportKind::Stdio {
 949            stdin_writer: Some(stdin_writer),
 950            stdout_reader: Some(stdout_reader),
 951        };
 952
 953        let mut this = Self {
 954            request_handlers: Arc::new(Mutex::new(HashMap::default())),
 955            response_handlers: Arc::new(Mutex::new(HashMap::default())),
 956            message_handler: None,
 957            kind,
 958        };
 959
 960        let request_handlers = this.request_handlers.clone();
 961        let response_handlers = this.response_handlers.clone();
 962
 963        this.message_handler = Some(cx.background_spawn(Self::handle_messages(
 964            request_handlers,
 965            response_handlers,
 966            stdin_reader,
 967            stdout_writer,
 968        )));
 969
 970        Ok(this)
 971    }
 972}
 973
 974#[cfg(any(test, feature = "test-support"))]
 975impl Transport for FakeTransport {
 976    fn tcp_arguments(&self) -> Option<TcpArguments> {
 977        match &self.kind {
 978            FakeTransportKind::Stdio { .. } => None,
 979            FakeTransportKind::Tcp { connection, .. } => Some(connection.clone()),
 980        }
 981    }
 982
 983    fn connect(
 984        &mut self,
 985    ) -> Task<
 986        Result<(
 987            Box<dyn AsyncWrite + Unpin + Send + 'static>,
 988            Box<dyn AsyncRead + Unpin + Send + 'static>,
 989        )>,
 990    > {
 991        let result = match &mut self.kind {
 992            FakeTransportKind::Stdio {
 993                stdin_writer,
 994                stdout_reader,
 995            } => util::maybe!({
 996                Ok((
 997                    Box::new(stdin_writer.take().context("Cannot reconnect")?) as _,
 998                    Box::new(stdout_reader.take().context("Cannot reconnect")?) as _,
 999                ))
1000            }),
1001            FakeTransportKind::Tcp { executor, .. } => {
1002                let (stdin_writer, stdin_reader) = async_pipe::pipe();
1003                let (stdout_writer, stdout_reader) = async_pipe::pipe();
1004
1005                let request_handlers = self.request_handlers.clone();
1006                let response_handlers = self.response_handlers.clone();
1007
1008                self.message_handler = Some(executor.spawn(Self::handle_messages(
1009                    request_handlers,
1010                    response_handlers,
1011                    stdin_reader,
1012                    stdout_writer,
1013                )));
1014
1015                Ok((Box::new(stdin_writer) as _, Box::new(stdout_reader) as _))
1016            }
1017        };
1018        Task::ready(result)
1019    }
1020
1021    fn has_adapter_logs(&self) -> bool {
1022        false
1023    }
1024
1025    fn kill(&mut self) {
1026        self.message_handler.take();
1027    }
1028
1029    #[cfg(any(test, feature = "test-support"))]
1030    fn as_fake(&self) -> &FakeTransport {
1031        self
1032    }
1033}
1034
1035struct Child {
1036    process: smol::process::Child,
1037}
1038
1039impl std::ops::Deref for Child {
1040    type Target = smol::process::Child;
1041
1042    fn deref(&self) -> &Self::Target {
1043        &self.process
1044    }
1045}
1046
1047impl std::ops::DerefMut for Child {
1048    fn deref_mut(&mut self) -> &mut Self::Target {
1049        &mut self.process
1050    }
1051}
1052
1053impl Child {
1054    fn into_inner(self) -> smol::process::Child {
1055        self.process
1056    }
1057
1058    #[cfg(not(windows))]
1059    fn spawn(mut command: std::process::Command, stdin: Stdio) -> Result<Self> {
1060        util::set_pre_exec_to_start_new_session(&mut command);
1061        let process = smol::process::Command::from(command)
1062            .stdin(stdin)
1063            .stdout(Stdio::piped())
1064            .stderr(Stdio::piped())
1065            .spawn()?;
1066        Ok(Self { process })
1067    }
1068
1069    #[cfg(windows)]
1070    fn spawn(command: std::process::Command, stdin: Stdio) -> Result<Self> {
1071        // TODO(windows): create a job object and add the child process handle to it,
1072        // see https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects
1073        let process = smol::process::Command::from(command)
1074            .stdin(stdin)
1075            .stdout(Stdio::piped())
1076            .stderr(Stdio::piped())
1077            .spawn()?;
1078        Ok(Self { process })
1079    }
1080
1081    #[cfg(not(windows))]
1082    fn kill(&mut self) {
1083        let pid = self.process.id();
1084        unsafe {
1085            libc::killpg(pid as i32, libc::SIGKILL);
1086        }
1087    }
1088
1089    #[cfg(windows)]
1090    fn kill(&mut self) {
1091        // TODO(windows): terminate the job object in kill
1092        let _ = self.process.kill();
1093    }
1094}