Better handle interrupted connections for shared SSH (#19925)

Conrad Irwin and Mikayla created

Co-Authored-By: Mikayla <mikayla@zed.dev>

Change summary

crates/remote/src/ssh_session.rs | 34 ++++++++++++++++++++++++++--------
1 file changed, 26 insertions(+), 8 deletions(-)

Detailed changes

crates/remote/src/ssh_session.rs 🔗

@@ -1288,6 +1288,7 @@ impl SshRemoteConnection {
     ) -> Result<Self> {
         use futures::AsyncWriteExt as _;
         use futures::{io::BufReader, AsyncBufReadExt as _};
+        use smol::net::unix::UnixStream;
         use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
         use util::ResultExt as _;
 
@@ -1304,6 +1305,9 @@ impl SshRemoteConnection {
         let listener =
             UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 
+        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<UnixStream>();
+        let mut kill_tx = Some(askpass_kill_master_tx);
+
         let askpass_task = cx.spawn({
             let delegate = delegate.clone();
             |mut cx| async move {
@@ -1327,6 +1331,11 @@ impl SshRemoteConnection {
                         .log_err()
                     {
                         stream.write_all(password.as_bytes()).await.log_err();
+                    } else {
+                        if let Some(kill_tx) = kill_tx.take() {
+                            kill_tx.send(stream).log_err();
+                            break;
+                        }
                     }
                 }
             }
@@ -1347,6 +1356,7 @@ impl SshRemoteConnection {
         // the connection and keep it open, allowing other ssh commands to reuse it
         // via a control socket.
         let socket_path = temp_dir.path().join("ssh.sock");
+
         let mut master_process = process::Command::new("ssh")
             .stdin(Stdio::null())
             .stdout(Stdio::piped())
@@ -1369,20 +1379,28 @@ impl SshRemoteConnection {
 
         // Wait for this ssh process to close its stdout, indicating that authentication
         // has completed.
-        let stdout = master_process.stdout.as_mut().unwrap();
+        let mut stdout = master_process.stdout.take().unwrap();
         let mut output = Vec::new();
         let connection_timeout = Duration::from_secs(10);
 
         let result = select_biased! {
             _ = askpass_opened_rx.fuse() => {
-                // If the askpass script has opened, that means the user is typing
-                // their password, in which case we don't want to timeout anymore,
-                // since we know a connection has been established.
-                stdout.read_to_end(&mut output).await?;
-                Ok(())
+                select_biased! {
+                    stream = askpass_kill_master_rx.fuse() => {
+                        master_process.kill().ok();
+                        drop(stream);
+                        Err(anyhow!("SSH connection canceled"))
+                    }
+                    // If the askpass script has opened, that means the user is typing
+                    // their password, in which case we don't want to timeout anymore,
+                    // since we know a connection has been established.
+                    result = stdout.read_to_end(&mut output).fuse() => {
+                        result?;
+                        Ok(())
+                    }
+                }
             }
-            result = stdout.read_to_end(&mut output).fuse() => {
-                result?;
+            _ = stdout.read_to_end(&mut output).fuse() => {
                 Ok(())
             }
             _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {