windows: Use nc-esque ssh askpass auth for remoting (#39646)

Piotr Osiewicz created

This lets us avoid storing user PW in ZED_ASKPASS_PASSWORD env var.
Release Notes:

- N/A

Change summary

Cargo.lock                               |   1 
crates/askpass/Cargo.toml                |   1 
crates/askpass/src/askpass.rs            | 245 ++++++++++++++++---------
crates/askpass/src/encrypted_password.rs |  86 +++-----
crates/project/src/git_store.rs          |   7 
crates/remote/src/transport/ssh.rs       |  33 +-
6 files changed, 213 insertions(+), 160 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -808,7 +808,6 @@ dependencies = [
  "gpui",
  "log",
  "net",
- "proto",
  "smol",
  "tempfile",
  "windows 0.61.1",

crates/askpass/Cargo.toml 🔗

@@ -16,7 +16,6 @@ anyhow.workspace = true
 futures.workspace = true
 gpui.workspace = true
 net.workspace = true
-proto.workspace = true
 smol.workspace = true
 log.workspace = true
 tempfile.workspace = true

crates/askpass/src/askpass.rs 🔗

@@ -1,10 +1,16 @@
 mod encrypted_password;
 
-pub use encrypted_password::{EncryptedPassword, ProcessExt};
-use util::paths::PathExt;
+pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
 
+use net::async_net::UnixListener;
+use smol::lock::Mutex;
+use util::fs::make_file_executable;
+
+use std::ffi::OsStr;
+use std::ops::ControlFlow;
+use std::sync::Arc;
 use std::sync::OnceLock;
-use std::{ffi::OsStr, time::Duration};
+use std::time::Duration;
 
 use anyhow::{Context as _, Result};
 use futures::channel::{mpsc, oneshot};
@@ -14,9 +20,7 @@ use futures::{
 };
 use gpui::{AsyncApp, BackgroundExecutor, Task};
 use smol::fs;
-use util::{ResultExt as _, debug_panic};
-
-use crate::encrypted_password::decrypt;
+use util::{ResultExt as _, debug_panic, maybe, paths::PathExt};
 
 /// Path to the program used for askpass
 ///
@@ -32,6 +36,7 @@ pub enum AskPassResult {
 
 pub struct AskPassDelegate {
     tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
+    executor: BackgroundExecutor,
     _task: Task<()>,
 }
 
@@ -49,24 +54,27 @@ impl AskPassDelegate {
                 password_prompt(prompt, channel, cx);
             }
         });
-        Self { tx, _task: task }
+        Self {
+            tx,
+            _task: task,
+            executor: cx.background_executor().clone(),
+        }
     }
 
-    pub async fn ask_password(&mut self, prompt: String) -> Option<EncryptedPassword> {
-        let (tx, rx) = oneshot::channel();
-        self.tx.send((prompt, tx)).await.ok()?;
-        rx.await.ok()
+    pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
+        let mut this_tx = self.tx.clone();
+        self.executor.spawn(async move {
+            let (tx, rx) = oneshot::channel();
+            this_tx.send((prompt, tx)).await.ok()?;
+            rx.await.ok()
+        })
     }
 }
 
 pub struct AskPassSession {
-    #[cfg(not(target_os = "windows"))]
-    script_path: std::path::PathBuf,
-    #[cfg(target_os = "windows")]
-    askpass_helper: String,
     #[cfg(target_os = "windows")]
     secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
-    _askpass_task: Task<()>,
+    askpass_task: PasswordProxy,
     askpass_opened_rx: Option<oneshot::Receiver<()>>,
     askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 }
@@ -81,103 +89,57 @@ impl AskPassSession {
     /// You must retain this session until the master process exits.
     #[must_use]
     pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
-        use net::async_net::UnixListener;
-        use util::fs::make_file_executable;
-
         #[cfg(target_os = "windows")]
         let secret = std::sync::Arc::new(OnceLock::new());
-        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
-        let askpass_socket = temp_dir.path().join("askpass.sock");
-        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
         let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
-        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
-
-        let current_exec =
-            std::env::current_exe().context("Failed to determine current zed executable path.")?;
 
-        let askpass_program = ASKPASS_PROGRAM
-            .get_or_init(|| current_exec)
-            .try_shell_safe()
-            .context("Failed to shell-escape Askpass program path.")?
-            .to_string();
+        let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
 
         let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
-        let mut kill_tx = Some(askpass_kill_master_tx);
+        let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
 
         #[cfg(target_os = "windows")]
         let askpass_secret = secret.clone();
-        let askpass_task = executor.spawn(async move {
-            let mut askpass_opened_tx = Some(askpass_opened_tx);
-
-            while let Ok((mut stream, _)) = listener.accept().await {
-                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
-                    askpass_opened_tx.send(()).ok();
-                }
-                let mut buffer = Vec::new();
-                let mut reader = BufReader::new(&mut stream);
-                if reader.read_until(b'\0', &mut buffer).await.is_err() {
-                    buffer.clear();
-                }
-                let prompt = String::from_utf8_lossy(&buffer);
-                if let Some(password) = delegate.ask_password(prompt.into_owned()).await {
-                    #[cfg(target_os = "windows")]
-                    {
-                        askpass_secret.get_or_init(|| password.clone());
-                    }
-                    if let Ok(decrypted) = decrypt(password) {
-                        stream.write_all(decrypted.as_bytes()).await.log_err();
+        let get_password = {
+            let executor = executor.clone();
+
+            move |prompt| {
+                let prompt = delegate.ask_password(prompt);
+                let kill_tx = kill_tx.clone();
+                let askpass_opened_tx = askpass_opened_tx.clone();
+                #[cfg(target_os = "windows")]
+                let askpass_secret = askpass_secret.clone();
+                executor.spawn(async move {
+                    if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
+                        askpass_opened_tx.send(()).ok();
                     }
-                } else {
-                    if let Some(kill_tx) = kill_tx.take() {
-                        kill_tx.send(()).log_err();
+                    if let Some(password) = prompt.await {
+                        #[cfg(target_os = "windows")]
+                        {
+                            _ = askpass_secret.set(password.clone());
+                        }
+                        ControlFlow::Continue(Ok(password))
+                    } else {
+                        if let Some(kill_tx) = kill_tx.lock().await.take() {
+                            kill_tx.send(()).log_err();
+                        }
+                        ControlFlow::Break(())
                     }
-                    // note: we expect the caller to drop this task when it's done.
-                    // We need to keep the stream open until the caller is done to avoid
-                    // spurious errors from ssh.
-                    std::future::pending::<()>().await;
-                    drop(stream);
-                }
+                })
             }
-            drop(temp_dir)
-        });
-
-        // Create an askpass script that communicates back to this process.
-        let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
-        fs::write(&askpass_script_path, askpass_script)
-            .await
-            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
-        make_file_executable(&askpass_script_path).await?;
-        #[cfg(target_os = "windows")]
-        let askpass_helper = format!(
-            "powershell.exe -ExecutionPolicy Bypass -File {}",
-            askpass_script_path.display()
-        );
+        };
+        let askpass_task = PasswordProxy::new(get_password, executor.clone()).await?;
 
         Ok(Self {
-            #[cfg(not(target_os = "windows"))]
-            script_path: askpass_script_path,
-
             #[cfg(target_os = "windows")]
             secret,
-            #[cfg(target_os = "windows")]
-            askpass_helper,
 
-            _askpass_task: askpass_task,
+            askpass_task,
             askpass_kill_master_rx: Some(askpass_kill_master_rx),
             askpass_opened_rx: Some(askpass_opened_rx),
         })
     }
 
-    #[cfg(not(target_os = "windows"))]
-    pub fn script_path(&self) -> impl AsRef<OsStr> {
-        &self.script_path
-    }
-
-    #[cfg(target_os = "windows")]
-    pub fn script_path(&self) -> impl AsRef<OsStr> {
-        &self.askpass_helper
-    }
-
     // This will run the askpass task forever, resolving as many authentication requests as needed.
     // The caller is responsible for examining the result of their own commands and cancelling this
     // future when this is no longer needed. Note that this can only be called once, but due to the
@@ -209,8 +171,109 @@ impl AskPassSession {
     pub fn get_password(&self) -> Option<EncryptedPassword> {
         self.secret.get().cloned()
     }
+
+    pub fn script_path(&self) -> impl AsRef<OsStr> {
+        self.askpass_task.script_path()
+    }
+}
+
+pub struct PasswordProxy {
+    _task: Task<()>,
+    #[cfg(not(target_os = "windows"))]
+    askpass_script_path: std::path::PathBuf,
+    #[cfg(target_os = "windows")]
+    askpass_helper: String,
 }
 
+impl PasswordProxy {
+    pub async fn new(
+        mut get_password: impl FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
+        + 'static
+        + Send
+        + Sync,
+        executor: BackgroundExecutor,
+    ) -> Result<Self> {
+        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
+        let askpass_socket = temp_dir.path().join("askpass.sock");
+        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
+        let current_exec =
+            std::env::current_exe().context("Failed to determine current zed executable path.")?;
+
+        let askpass_program = ASKPASS_PROGRAM
+            .get_or_init(|| current_exec)
+            .try_shell_safe()
+            .context("Failed to shell-escape Askpass program path.")?
+            .to_string();
+        // Create an askpass script that communicates back to this process.
+        let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
+        let _task = executor.spawn(async move {
+            maybe!(async move {
+                let listener =
+                    UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
+
+                while let Ok((mut stream, _)) = listener.accept().await {
+                    let mut buffer = Vec::new();
+                    let mut reader = BufReader::new(&mut stream);
+                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
+                        buffer.clear();
+                    }
+                    let prompt = String::from_utf8_lossy(&buffer).into_owned();
+                    let password = get_password(prompt).await;
+                    match password {
+                        ControlFlow::Continue(password) => {
+                            if let Ok(password) = password
+                                && let Ok(decrypted) =
+                                    password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
+                            {
+                                stream.write_all(decrypted.as_bytes()).await.log_err();
+                            }
+                        }
+                        ControlFlow::Break(()) => {
+                            // note: we expect the caller to drop this task when it's done.
+                            // We need to keep the stream open until the caller is done to avoid
+                            // spurious errors from ssh.
+                            std::future::pending::<()>().await;
+                            drop(stream);
+                        }
+                    }
+                }
+                drop(temp_dir);
+                Result::<_, anyhow::Error>::Ok(())
+            })
+            .await
+            .log_err();
+        });
+
+        fs::write(&askpass_script_path, askpass_script)
+            .await
+            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
+        make_file_executable(&askpass_script_path).await?;
+        #[cfg(target_os = "windows")]
+        let askpass_helper = format!(
+            "powershell.exe -ExecutionPolicy Bypass -File {}",
+            askpass_script_path.display()
+        );
+
+        Ok(Self {
+            _task,
+            #[cfg(not(target_os = "windows"))]
+            askpass_script_path,
+            #[cfg(target_os = "windows")]
+            askpass_helper,
+        })
+    }
+
+    pub fn script_path(&self) -> impl AsRef<OsStr> {
+        #[cfg(not(target_os = "windows"))]
+        {
+            &self.askpass_script_path
+        }
+        #[cfg(target_os = "windows")]
+        {
+            &self.askpass_helper
+        }
+    }
+}
 /// The main function for when Zed is running in netcat mode for use in askpass.
 /// Called from both the remote server binary and the zed binary in their respective main functions.
 pub fn main(socket: &str) {

crates/askpass/src/encrypted_password.rs 🔗

@@ -21,27 +21,6 @@ type LengthWithoutPadding = u32;
 #[derive(Clone)]
 pub struct EncryptedPassword(Vec<u8>, LengthWithoutPadding);
 
-pub trait ProcessExt {
-    fn encrypted_env(&mut self, name: &str, value: EncryptedPassword) -> &mut Self;
-}
-
-impl ProcessExt for smol::process::Command {
-    fn encrypted_env(&mut self, name: &str, value: EncryptedPassword) -> &mut Self {
-        if let Ok(password) = decrypt(value) {
-            self.env(name, password);
-        }
-        self
-    }
-}
-
-impl TryFrom<EncryptedPassword> for proto::AskPassResponse {
-    type Error = anyhow::Error;
-    fn try_from(pw: EncryptedPassword) -> Result<Self, Self::Error> {
-        let pw = decrypt(pw)?;
-        Ok(Self { response: pw })
-    }
-}
-
 impl Drop for EncryptedPassword {
     fn drop(&mut self) {
         self.0.zeroize();
@@ -79,38 +58,45 @@ impl TryFrom<&str> for EncryptedPassword {
     }
 }
 
-pub(crate) fn decrypt(mut password: EncryptedPassword) -> Result<String> {
-    #[cfg(windows)]
-    {
-        use anyhow::Context;
-        use windows::Win32::Security::Cryptography::{
-            CRYPTPROTECTMEMORY_BLOCK_SIZE, CRYPTPROTECTMEMORY_SAME_PROCESS, CryptUnprotectMemory,
-        };
-        assert_eq!(
-            password.0.len() % CRYPTPROTECTMEMORY_BLOCK_SIZE as usize,
-            0,
-            "Violated pre-condition (buffer size <{}> must be a multiple of CRYPTPROTECTMEMORY_BLOCK_SIZE <{}>) for CryptUnprotectMemory.",
-            password.0.len(),
-            CRYPTPROTECTMEMORY_BLOCK_SIZE
-        );
-        if password.1 != 0 {
-            unsafe {
-                CryptUnprotectMemory(
-                    password.0.as_mut_ptr() as _,
-                    password.0.len().try_into()?,
-                    CRYPTPROTECTMEMORY_SAME_PROCESS,
-                )
-                .context("while decrypting a SSH password")?
+/// Read the docs for [EncryptedPassword]; please take care of not storing the plaintext string in memory for extended
+/// periods of time.
+pub struct IKnowWhatIAmDoingAndIHaveReadTheDocs;
+
+impl EncryptedPassword {
+    pub fn decrypt(mut self, _: IKnowWhatIAmDoingAndIHaveReadTheDocs) -> Result<String> {
+        #[cfg(windows)]
+        {
+            use anyhow::Context;
+            use windows::Win32::Security::Cryptography::{
+                CRYPTPROTECTMEMORY_BLOCK_SIZE, CRYPTPROTECTMEMORY_SAME_PROCESS,
+                CryptUnprotectMemory,
             };
+            assert_eq!(
+                self.0.len() % CRYPTPROTECTMEMORY_BLOCK_SIZE as usize,
+                0,
+                "Violated pre-condition (buffer size <{}> must be a multiple of CRYPTPROTECTMEMORY_BLOCK_SIZE <{}>) for CryptUnprotectMemory.",
+                self.0.len(),
+                CRYPTPROTECTMEMORY_BLOCK_SIZE
+            );
+            if self.1 != 0 {
+                unsafe {
+                    CryptUnprotectMemory(
+                        self.0.as_mut_ptr() as _,
+                        self.0.len().try_into()?,
+                        CRYPTPROTECTMEMORY_SAME_PROCESS,
+                    )
+                    .context("while decrypting a SSH password")?
+                };
 
-            {
-                // Remove padding
-                _ = password.0.drain(password.1 as usize..);
+                {
+                    // Remove padding
+                    _ = self.0.drain(self.1 as usize..);
+                }
             }
-        }
 
-        Ok(String::from_utf8(std::mem::take(&mut password.0))?)
+            Ok(String::from_utf8(std::mem::take(&mut self.0))?)
+        }
+        #[cfg(not(windows))]
+        Ok(String::from_utf8(std::mem::take(&mut self.0))?)
     }
-    #[cfg(not(windows))]
-    Ok(String::from_utf8(std::mem::take(&mut password.0))?)
 }

crates/project/src/git_store.rs 🔗

@@ -7,7 +7,7 @@ use crate::{
     worktree_store::{WorktreeStore, WorktreeStoreEvent},
 };
 use anyhow::{Context as _, Result, anyhow, bail};
-use askpass::{AskPassDelegate, EncryptedPassword};
+use askpass::{AskPassDelegate, EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
 use buffer_diff::{BufferDiff, BufferDiffEvent};
 use client::ProjectId;
 use collections::HashMap;
@@ -2120,7 +2120,10 @@ impl GitStore {
             .lock()
             .insert(envelope.payload.askpass_id, askpass);
 
-        response.try_into()
+        // In fact, we don't quite know what we're doing here, as we're sending askpass password unencrypted, but..
+        Ok(proto::AskPassResponse {
+            response: response.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)?,
+        })
     }
 
     async fn handle_check_for_pushed_commits(

crates/remote/src/transport/ssh.rs 🔗

@@ -70,14 +70,13 @@ impl From<settings::SshConnection> for SshConnectionOptions {
     }
 }
 
-#[derive(Clone)]
 struct SshSocket {
     connection_options: SshConnectionOptions,
     #[cfg(not(target_os = "windows"))]
-    socket_path: PathBuf,
+    socket_path: std::path::PathBuf,
     envs: HashMap<String, String>,
     #[cfg(target_os = "windows")]
-    password: askpass::EncryptedPassword,
+    _proxy: askpass::PasswordProxy,
 }
 
 macro_rules! shell_script {
@@ -343,16 +342,17 @@ impl SshRemoteConnection {
         }
 
         #[cfg(not(target_os = "windows"))]
-        let socket = SshSocket::new(connection_options, socket_path)?;
+        let socket = SshSocket::new(connection_options, socket_path).await?;
         #[cfg(target_os = "windows")]
         let socket = SshSocket::new(
             connection_options,
-            &temp_dir,
             askpass
                 .get_password()
                 .or_else(|| askpass::EncryptedPassword::try_from("").ok())
                 .context("Failed to fetch askpass password")?,
-        )?;
+            cx.background_executor().clone(),
+        )
+        .await?;
         drop(askpass);
 
         let ssh_platform = socket.platform().await?;
@@ -659,7 +659,7 @@ impl SshRemoteConnection {
 
 impl SshSocket {
     #[cfg(not(target_os = "windows"))]
-    fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
+    async fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
         Ok(Self {
             connection_options: options,
             envs: HashMap::default(),
@@ -668,21 +668,26 @@ impl SshSocket {
     }
 
     #[cfg(target_os = "windows")]
-    fn new(
+    async fn new(
         options: SshConnectionOptions,
-        temp_dir: &TempDir,
         password: askpass::EncryptedPassword,
+        executor: gpui::BackgroundExecutor,
     ) -> Result<Self> {
-        let askpass_script = temp_dir.path().join("askpass.bat");
-        std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
         let mut envs = HashMap::default();
+        let get_password =
+            move |_| Task::ready(std::ops::ControlFlow::Continue(Ok(password.clone())));
+
+        let _proxy = askpass::PasswordProxy::new(get_password, executor).await?;
         envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
-        envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
+        envs.insert(
+            "SSH_ASKPASS".into(),
+            _proxy.script_path().as_ref().display().to_string(),
+        );
 
         Ok(Self {
             connection_options: options,
             envs,
-            password,
+            _proxy,
         })
     }
 
@@ -736,14 +741,12 @@ impl SshSocket {
 
     #[cfg(target_os = "windows")]
     fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
-        use askpass::ProcessExt;
         command
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .stderr(Stdio::piped())
             .args(self.connection_options.additional_args())
             .envs(self.envs.clone())
-            .encrypted_env("ZED_SSH_ASKPASS", self.password.clone())
     }
 
     // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.