Fix JumpHost on Windows (#40713)

localcc created

Closes #39382 

Release Notes:

- Fixed Windows specific ssh jumphost connection issues

Change summary

crates/remote/src/transport/ssh.rs | 184 +++++++++++++++++++++++++------
1 file changed, 144 insertions(+), 40 deletions(-)

Detailed changes

crates/remote/src/transport/ssh.rs 🔗

@@ -34,7 +34,7 @@ use util::{
 
 pub(crate) struct SshRemoteConnection {
     socket: SshSocket,
-    master_process: Mutex<Option<Child>>,
+    master_process: Mutex<Option<MasterProcess>>,
     remote_binary_path: Option<Arc<RelPath>>,
     ssh_platform: RemotePlatform,
     ssh_path_style: PathStyle,
@@ -80,6 +80,129 @@ struct SshSocket {
     _proxy: askpass::PasswordProxy,
 }
 
+struct MasterProcess {
+    process: Child,
+}
+
+#[cfg(not(target_os = "windows"))]
+impl MasterProcess {
+    pub fn new(
+        askpass_script_path: &std::ffi::OsStr,
+        additional_args: Vec<String>,
+        socket_path: &std::path::Path,
+        url: &str,
+    ) -> Result<Self> {
+        let args = [
+            "-N",
+            "-o",
+            "ControlPersist=no",
+            "-o",
+            "ControlMaster=yes",
+            "-o",
+        ];
+
+        let mut master_process = util::command::new_smol_command("ssh");
+        master_process
+            .kill_on_drop(true)
+            .stdin(Stdio::null())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            .env("SSH_ASKPASS_REQUIRE", "force")
+            .env("SSH_ASKPASS", askpass_script_path)
+            .args(additional_args)
+            .args(args);
+
+        master_process.arg(format!("ControlPath={}", socket_path.display()));
+
+        let process = master_process.arg(&url).spawn()?;
+
+        Ok(MasterProcess { process })
+    }
+
+    pub async fn wait_connected(&mut self) -> Result<()> {
+        let Some(mut stdout) = self.process.stdout.take() else {
+            anyhow::bail!("ssh process stdout capture failed");
+        };
+
+        let mut output = Vec::new();
+        stdout.read_to_end(&mut output).await?;
+        Ok(())
+    }
+}
+
+#[cfg(target_os = "windows")]
+impl MasterProcess {
+    const CONNECTION_ESTABLISHED_MAGIC: &str = "ZED_SSH_CONNECTION_ESTABLISHED";
+
+    pub fn new(
+        askpass_script_path: &std::ffi::OsStr,
+        additional_args: Vec<String>,
+        url: &str,
+    ) -> Result<Self> {
+        // On Windows, `ControlMaster` and `ControlPath` are not supported:
+        // https://github.com/PowerShell/Win32-OpenSSH/issues/405
+        // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
+        //
+        // Using an ugly workaround to detect connection establishment
+        // -N doesn't work with JumpHosts as windows openssh never closes stdin in that case
+        let args = [
+            "-t",
+            &format!("echo '{}'; exec $0", Self::CONNECTION_ESTABLISHED_MAGIC),
+        ];
+
+        let mut master_process = util::command::new_smol_command("ssh");
+        master_process
+            .kill_on_drop(true)
+            .stdin(Stdio::null())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            .env("SSH_ASKPASS_REQUIRE", "force")
+            .env("SSH_ASKPASS", askpass_script_path)
+            .args(additional_args)
+            .arg(url)
+            .args(args);
+
+        let process = master_process.spawn()?;
+
+        Ok(MasterProcess { process })
+    }
+
+    pub async fn wait_connected(&mut self) -> Result<()> {
+        use smol::io::AsyncBufReadExt;
+
+        let Some(stdout) = self.process.stdout.take() else {
+            anyhow::bail!("ssh process stdout capture failed");
+        };
+
+        let mut reader = smol::io::BufReader::new(stdout);
+
+        let mut line = String::new();
+
+        loop {
+            let n = reader.read_line(&mut line).await?;
+            if n == 0 {
+                anyhow::bail!("ssh process exited before connection established");
+            }
+
+            if line.contains(Self::CONNECTION_ESTABLISHED_MAGIC) {
+                return Ok(());
+            }
+        }
+    }
+}
+
+impl AsRef<Child> for MasterProcess {
+    fn as_ref(&self) -> &Child {
+        &self.process
+    }
+}
+
+impl AsMut<Child> for MasterProcess {
+    fn as_mut(&mut self) -> &mut Child {
+        &mut self.process
+    }
+}
+
 macro_rules! shell_script {
     ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
         format!(
@@ -97,8 +220,8 @@ impl RemoteConnection for SshRemoteConnection {
         let Some(mut process) = self.master_process.lock().take() else {
             return Ok(());
         };
-        process.kill().ok();
-        process.status().await?;
+        process.as_mut().kill().ok();
+        process.as_mut().status().await?;
         Ok(())
     }
 
@@ -302,45 +425,25 @@ impl SshRemoteConnection {
         #[cfg(not(target_os = "windows"))]
         let socket_path = temp_dir.path().join("ssh.sock");
 
-        let mut master_process = {
-            #[cfg(not(target_os = "windows"))]
-            let args = [
-                "-N",
-                "-o",
-                "ControlPersist=no",
-                "-o",
-                "ControlMaster=yes",
-                "-o",
-            ];
-            // On Windows, `ControlMaster` and `ControlPath` are not supported:
-            // https://github.com/PowerShell/Win32-OpenSSH/issues/405
-            // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
-            #[cfg(target_os = "windows")]
-            let args = ["-N"];
-            let mut master_process = util::command::new_smol_command("ssh");
-            master_process
-                .kill_on_drop(true)
-                .stdin(Stdio::null())
-                .stdout(Stdio::piped())
-                .stderr(Stdio::piped())
-                .env("SSH_ASKPASS_REQUIRE", "force")
-                .env("SSH_ASKPASS", askpass.script_path())
-                .args(connection_options.additional_args())
-                .args(args);
-            #[cfg(not(target_os = "windows"))]
-            master_process.arg(format!("ControlPath={}", socket_path.display()));
-            master_process.arg(&url).spawn()?
-        };
-        // Wait for this ssh process to close its stdout, indicating that authentication
-        // has completed.
-        let mut stdout = master_process.stdout.take().unwrap();
-        let mut output = Vec::new();
+        #[cfg(target_os = "windows")]
+        let mut master_process = MasterProcess::new(
+            askpass.script_path().as_ref(),
+            connection_options.additional_args(),
+            &url,
+        )?;
+        #[cfg(not(target_os = "windows"))]
+        let mut master_process = MasterProcess::new(
+            askpass.script_path().as_ref(),
+            connection_options.additional_args(),
+            &socket_path,
+            &url,
+        )?;
 
         let result = select_biased! {
             result = askpass.run().fuse() => {
                 match result {
                     AskPassResult::CancelledByUser => {
-                        master_process.kill().ok();
+                        master_process.as_mut().kill().ok();
                         anyhow::bail!("SSH connection canceled")
                     }
                     AskPassResult::Timedout => {
@@ -348,7 +451,7 @@ impl SshRemoteConnection {
                     }
                 }
             }
-            _ = stdout.read_to_end(&mut output).fuse() => {
+            _ = master_process.wait_connected().fuse() => {
                 anyhow::Ok(())
             }
         };
@@ -357,9 +460,10 @@ impl SshRemoteConnection {
             return Err(e.context("Failed to connect to host"));
         }
 
-        if master_process.try_status()?.is_some() {
+        if master_process.as_mut().try_status()?.is_some() {
+            let mut output = Vec::new();
             output.clear();
-            let mut stderr = master_process.stderr.take().unwrap();
+            let mut stderr = master_process.as_mut().stderr.take().unwrap();
             stderr.read_to_end(&mut output).await?;
 
             let error_message = format!(