@@ -171,28 +171,6 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
))
}
}
-#[cfg(unix)]
-async fn read_with_timeout(
- stdout: &mut process::ChildStdout,
- timeout: Duration,
- output: &mut Vec<u8>,
-) -> Result<(), std::io::Error> {
- smol::future::or(
- async {
- stdout.read_to_end(output).await?;
- Ok::<_, std::io::Error>(())
- },
- async {
- smol::Timer::after(timeout).await;
-
- Err(std::io::Error::new(
- std::io::ErrorKind::TimedOut,
- "Read operation timed out",
- ))
- },
- )
- .await
-}
struct ChannelForwarder {
quit_tx: UnboundedSender<()>,
@@ -725,13 +703,19 @@ impl SshRemoteConnection {
// Create a domain socket listener to handle requests from the askpass program.
let askpass_socket = temp_dir.path().join("askpass.sock");
+ let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
let listener =
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
let askpass_task = cx.spawn({
let delegate = delegate.clone();
|mut cx| async move {
+ let mut askpass_opened_tx = Some(askpass_opened_tx);
+
while let Ok((mut stream, _)) = listener.accept().await {
+ if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
+ askpass_opened_tx.send(()).ok();
+ }
let mut buffer = Vec::new();
let mut reader = BufReader::new(&mut stream);
if reader.read_until(b'\0', &mut buffer).await.is_err() {
@@ -782,19 +766,28 @@ impl SshRemoteConnection {
let stdout = master_process.stdout.as_mut().unwrap();
let mut output = Vec::new();
let connection_timeout = Duration::from_secs(10);
- let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
- if let Err(e) = result {
- let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
- format!(
- "Failed to connect to host. Timed out after {:?}.",
- connection_timeout
- )
- } else {
- format!("Failed to connect to host: {}.", e)
- };
+ let result = select_biased! {
+ _ = askpass_opened_rx.fuse() => {
+ // If the askpass script has opened, that means the user is typing
+ // their password, in which case we don't want to timeout anymore,
+ // since we know a connection has been established.
+ stdout.read_to_end(&mut output).await?;
+ Ok(())
+ }
+ result = stdout.read_to_end(&mut output).fuse() => {
+ result?;
+ Ok(())
+ }
+ _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
+ Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
+ }
+ };
+
+ if let Err(e) = result {
+ let error_message = format!("Failed to connect to host: {}.", e);
delegate.set_error(error_message, cx);
- return Err(e.into());
+ return Err(e);
}
drop(askpass_task);
@@ -803,10 +796,10 @@ impl SshRemoteConnection {
output.clear();
let mut stderr = master_process.stderr.take().unwrap();
stderr.read_to_end(&mut output).await?;
- Err(anyhow!(
- "failed to connect: {}",
- String::from_utf8_lossy(&output)
- ))?;
+
+ let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
+ delegate.set_error(error_message.clone(), cx);
+ Err(anyhow!(error_message))?;
}
Ok(Self {