ssh: Do not cancel connection process if user is typing password (#18812)

Bennet Bo Fenner and Thorsten created

Previously, the connection process would be cancelled after 10 seconds,
even if the connection was established successfully but the user was
still typing in a password.
We know recognize when the user is prompted for a password, and cancel
the timeout task.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

crates/remote/src/ssh_session.rs | 67 +++++++++++++++------------------
1 file changed, 30 insertions(+), 37 deletions(-)

Detailed changes

crates/remote/src/ssh_session.rs 🔗

@@ -171,28 +171,6 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
         ))
     }
 }
-#[cfg(unix)]
-async fn read_with_timeout(
-    stdout: &mut process::ChildStdout,
-    timeout: Duration,
-    output: &mut Vec<u8>,
-) -> Result<(), std::io::Error> {
-    smol::future::or(
-        async {
-            stdout.read_to_end(output).await?;
-            Ok::<_, std::io::Error>(())
-        },
-        async {
-            smol::Timer::after(timeout).await;
-
-            Err(std::io::Error::new(
-                std::io::ErrorKind::TimedOut,
-                "Read operation timed out",
-            ))
-        },
-    )
-    .await
-}
 
 struct ChannelForwarder {
     quit_tx: UnboundedSender<()>,
@@ -725,13 +703,19 @@ impl SshRemoteConnection {
 
         // Create a domain socket listener to handle requests from the askpass program.
         let askpass_socket = temp_dir.path().join("askpass.sock");
+        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
         let listener =
             UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 
         let askpass_task = cx.spawn({
             let delegate = delegate.clone();
             |mut cx| async move {
+                let mut askpass_opened_tx = Some(askpass_opened_tx);
+
                 while let Ok((mut stream, _)) = listener.accept().await {
+                    if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
+                        askpass_opened_tx.send(()).ok();
+                    }
                     let mut buffer = Vec::new();
                     let mut reader = BufReader::new(&mut stream);
                     if reader.read_until(b'\0', &mut buffer).await.is_err() {
@@ -782,19 +766,28 @@ impl SshRemoteConnection {
         let stdout = master_process.stdout.as_mut().unwrap();
         let mut output = Vec::new();
         let connection_timeout = Duration::from_secs(10);
-        let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
-        if let Err(e) = result {
-            let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
-                format!(
-                    "Failed to connect to host. Timed out after {:?}.",
-                    connection_timeout
-                )
-            } else {
-                format!("Failed to connect to host: {}.", e)
-            };
 
+        let result = select_biased! {
+            _ = askpass_opened_rx.fuse() => {
+                // If the askpass script has opened, that means the user is typing
+                // their password, in which case we don't want to timeout anymore,
+                // since we know a connection has been established.
+                stdout.read_to_end(&mut output).await?;
+                Ok(())
+            }
+            result = stdout.read_to_end(&mut output).fuse() => {
+                result?;
+                Ok(())
+            }
+            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
+                Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
+            }
+        };
+
+        if let Err(e) = result {
+            let error_message = format!("Failed to connect to host: {}.", e);
             delegate.set_error(error_message, cx);
-            return Err(e.into());
+            return Err(e);
         }
 
         drop(askpass_task);
@@ -803,10 +796,10 @@ impl SshRemoteConnection {
             output.clear();
             let mut stderr = master_process.stderr.take().unwrap();
             stderr.read_to_end(&mut output).await?;
-            Err(anyhow!(
-                "failed to connect: {}",
-                String::from_utf8_lossy(&output)
-            ))?;
+
+            let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
+            delegate.set_error(error_message.clone(), cx);
+            Err(anyhow!(error_message))?;
         }
 
         Ok(Self {