askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
  4
  5use net::async_net::UnixListener;
  6use smol::lock::Mutex;
  7use util::fs::make_file_executable;
  8
  9use std::ffi::OsStr;
 10use std::ops::ControlFlow;
 11use std::sync::Arc;
 12use std::sync::OnceLock;
 13use std::time::Duration;
 14
 15use anyhow::{Context as _, Result};
 16use futures::channel::{mpsc, oneshot};
 17use futures::{
 18    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 19    select_biased,
 20};
 21use gpui::{AsyncApp, BackgroundExecutor, Task};
 22use smol::fs;
 23use util::{
 24    ResultExt as _, debug_panic, maybe,
 25    paths::PathExt,
 26    shell::{PosixShell, ShellKind},
 27};
 28
/// Path to the program the generated askpass script pipes its prompt into
/// (invoked as `<program> --askpass=<socket>`).
///
/// It can be set explicitly, once, with [`set_askpass_program`]. If it is still
/// unset when [`PasswordProxy::new`] first runs, it is initialized to
/// `std::env::current_exe()`. On Unix and remote servers it is normally left at that
/// default. On Windows the application is expected to point it at the CLI variant of
/// zed; that is not enforced here.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
 34
 35#[derive(PartialEq, Eq)]
 36pub enum AskPassResult {
 37    CancelledByUser,
 38    Timedout,
 39}
 40
/// Forwards askpass prompts to the `password_prompt` callback given to
/// [`AskPassDelegate::new`] and hands back the answers.
///
/// Fields are dropped in declaration order.
pub struct AskPassDelegate {
    // Queue of (prompt, reply channel) pairs consumed by the loop in `_task`.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    // Spawns the per-request tasks returned by `ask_password`.
    executor: BackgroundExecutor,
    // Handle to the prompt loop spawned in `new`, held so the loop lives as long as
    // the delegate.
    _task: Task<()>,
}
 46
 47impl AskPassDelegate {
 48    pub fn new(
 49        cx: &mut AsyncApp,
 50        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 51        + Send
 52        + Sync
 53        + 'static,
 54    ) -> Self {
 55        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 56        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 57            while let Some((prompt, channel)) = rx.next().await {
 58                password_prompt(prompt, channel, cx);
 59            }
 60        });
 61        Self {
 62            tx,
 63            _task: task,
 64            executor: cx.background_executor().clone(),
 65        }
 66    }
 67
 68    pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
 69        let mut this_tx = self.tx.clone();
 70        self.executor.spawn(async move {
 71            let (tx, rx) = oneshot::channel();
 72            this_tx.send((prompt, tx)).await.ok()?;
 73            rx.await.ok()
 74        })
 75    }
 76}
 77
/// An askpass proxy plus the one-shot signals that [`AskPassSession::run`] waits on.
///
/// Fields are dropped in declaration order, so `askpass_task` (the proxy) is dropped
/// before the receivers.
pub struct AskPassSession {
    // Windows only: the most recent password returned through the proxy.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<std::sync::Mutex<Option<EncryptedPassword>>>,
    // Owns the socket listener task and the generated askpass script.
    askpass_task: PasswordProxy,
    // Resolves when the first prompt arrives; taken (once) by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    // Resolves when a prompt yields no password; taken (once) by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
    executor: BackgroundExecutor,
}
 86
 87const ASKPASS_SCRIPT_NAME: &str = if cfg!(target_os = "windows") {
 88    "askpass.ps1"
 89} else {
 90    "askpass.sh"
 91};
 92
 93impl AskPassSession {
 94    /// This will create a new AskPassSession.
 95    /// You must retain this session until the master process exits.
 96    #[must_use]
 97    pub async fn new(executor: BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 98        #[cfg(target_os = "windows")]
 99        let secret = std::sync::Arc::new(std::sync::Mutex::new(None));
100
101        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
102
103        let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
104
105        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
106        let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
107
108        let get_password = {
109            let executor = executor.clone();
110
111            #[cfg(target_os = "windows")]
112            let askpass_secret = secret.clone();
113            move |prompt| {
114                let prompt = delegate.ask_password(prompt);
115                let kill_tx = kill_tx.clone();
116                let askpass_opened_tx = askpass_opened_tx.clone();
117                #[cfg(target_os = "windows")]
118                let askpass_secret = askpass_secret.clone();
119                executor.spawn(async move {
120                    if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
121                        askpass_opened_tx.send(()).ok();
122                    }
123                    if let Some(password) = prompt.await {
124                        #[cfg(target_os = "windows")]
125                        {
126                            askpass_secret.lock().unwrap().replace(password.clone());
127                        }
128                        ControlFlow::Continue(Ok(password))
129                    } else {
130                        if let Some(kill_tx) = kill_tx.lock().await.take() {
131                            kill_tx.send(()).log_err();
132                        }
133                        ControlFlow::Break(())
134                    }
135                })
136            }
137        };
138        let askpass_task = PasswordProxy::new(get_password, executor.clone()).await?;
139
140        Ok(Self {
141            #[cfg(target_os = "windows")]
142            secret,
143
144            askpass_task,
145            askpass_kill_master_rx: Some(askpass_kill_master_rx),
146            askpass_opened_rx: Some(askpass_opened_rx),
147            executor,
148        })
149    }
150
151    // This will run the askpass task forever, resolving as many authentication requests as needed.
152    // The caller is responsible for examining the result of their own commands and cancelling this
153    // future when this is no longer needed. Note that this can only be called once, but due to the
154    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
155    pub async fn run(&mut self) -> AskPassResult {
156        // This is the default timeout setting used by VSCode.
157        let connection_timeout = Duration::from_secs(17);
158        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
159        let askpass_kill_master_rx = self
160            .askpass_kill_master_rx
161            .take()
162            .expect("Only call run once");
163        let executor = self.executor.clone();
164
165        select_biased! {
166            _ = askpass_opened_rx.fuse() => {
167                // Note: this await can only resolve after we are dropped.
168                askpass_kill_master_rx.await.ok();
169                AskPassResult::CancelledByUser
170            }
171
172            _ = futures::FutureExt::fuse(executor.timer(connection_timeout)) => {
173                AskPassResult::Timedout
174            }
175        }
176    }
177
178    /// This will return the password that was last set by the askpass script.
179    #[cfg(target_os = "windows")]
180    pub fn get_password(&self) -> Option<EncryptedPassword> {
181        self.secret.lock().ok()?.clone()
182    }
183
184    pub fn script_path(&self) -> impl AsRef<OsStr> {
185        self.askpass_task.script_path()
186    }
187}
188
/// Answers askpass requests arriving on a Unix socket inside a temporary directory.
///
/// Fields are dropped in declaration order.
pub struct PasswordProxy {
    // Accept loop spawned in `new`; it owns the listener and the temporary directory.
    _task: Task<()>,
    // Path of the generated `askpass.sh` script.
    #[cfg(not(target_os = "windows"))]
    askpass_script_path: std::path::PathBuf,
    // Full `powershell.exe -ExecutionPolicy Bypass -File "<script>"` command line.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
}
196
197impl PasswordProxy {
198    pub async fn new(
199        mut get_password: impl FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
200        + 'static
201        + Send
202        + Sync,
203        executor: BackgroundExecutor,
204    ) -> Result<Self> {
205        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
206        let askpass_socket = temp_dir.path().join("askpass.sock");
207        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
208        let current_exec =
209            std::env::current_exe().context("Failed to determine current zed executable path.")?;
210
211        // TODO: inferred from the use of powershell.exe in askpass_helper_script
212        let shell_kind = if cfg!(windows) {
213            ShellKind::PowerShell
214        } else {
215            // TODO: Consider using the user's actual shell instead of hardcoding "sh"
216            ShellKind::Posix(PosixShell::Sh)
217        };
218        let askpass_program = ASKPASS_PROGRAM.get_or_init(|| current_exec);
219        // Create an askpass script that communicates back to this process.
220        let askpass_script = generate_askpass_script(shell_kind, askpass_program, &askpass_socket)?;
221        let _task = executor.spawn(async move {
222            maybe!(async move {
223                let listener =
224                    UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
225
226                while let Ok((mut stream, _)) = listener.accept().await {
227                    let mut buffer = Vec::new();
228                    let mut reader = BufReader::new(&mut stream);
229                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
230                        buffer.clear();
231                    }
232                    let prompt = String::from_utf8_lossy(&buffer).into_owned();
233                    let password = get_password(prompt).await;
234                    match password {
235                        ControlFlow::Continue(password) => {
236                            if let Ok(password) = password
237                                && let Ok(decrypted) =
238                                    password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
239                            {
240                                stream.write_all(decrypted.as_bytes()).await.log_err();
241                            }
242                        }
243                        ControlFlow::Break(()) => {
244                            // note: we expect the caller to drop this task when it's done.
245                            // We need to keep the stream open until the caller is done to avoid
246                            // spurious errors from ssh.
247                            std::future::pending::<()>().await;
248                            drop(stream);
249                        }
250                    }
251                }
252                drop(temp_dir);
253                Result::<_, anyhow::Error>::Ok(())
254            })
255            .await
256            .log_err();
257        });
258
259        fs::write(&askpass_script_path, askpass_script)
260            .await
261            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
262        make_file_executable(&askpass_script_path)
263            .await
264            .with_context(|| {
265                format!("marking askpass script executable at {askpass_script_path:?}")
266            })?;
267        // todo(shell): There might be no powershell on the system
268        #[cfg(target_os = "windows")]
269        let askpass_helper = format!(
270            "powershell.exe -ExecutionPolicy Bypass -File \"{}\"",
271            askpass_script_path.display()
272        );
273
274        Ok(Self {
275            _task,
276            #[cfg(not(target_os = "windows"))]
277            askpass_script_path,
278            #[cfg(target_os = "windows")]
279            askpass_helper,
280        })
281    }
282
283    pub fn script_path(&self) -> impl AsRef<OsStr> {
284        #[cfg(not(target_os = "windows"))]
285        {
286            &self.askpass_script_path
287        }
288        #[cfg(target_os = "windows")]
289        {
290            &self.askpass_helper
291        }
292    }
293}
294/// The main function for when Zed is running in netcat mode for use in askpass.
295/// Called from both the remote server binary and the zed binary in their respective main functions.
296pub fn main(socket: &str) {
297    use net::UnixStream;
298    use std::io::{self, Read, Write};
299    use std::process::exit;
300
301    let mut stream = match UnixStream::connect(socket) {
302        Ok(stream) => stream,
303        Err(err) => {
304            eprintln!("Error connecting to socket {}: {}", socket, err);
305            exit(1);
306        }
307    };
308
309    let mut buffer = Vec::new();
310    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
311        eprintln!("Error reading from stdin: {}", err);
312        exit(1);
313    }
314
315    #[cfg(target_os = "windows")]
316    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
317        buffer.pop();
318    }
319    if buffer.last() != Some(&b'\0') {
320        buffer.push(b'\0');
321    }
322
323    if let Err(err) = stream.write_all(&buffer) {
324        eprintln!("Error writing to socket: {}", err);
325        exit(1);
326    }
327
328    let mut response = Vec::new();
329    if let Err(err) = stream.read_to_end(&mut response) {
330        eprintln!("Error reading from socket: {}", err);
331        exit(1);
332    }
333
334    if let Err(err) = io::stdout().write_all(&response) {
335        eprintln!("Error writing to stdout: {}", err);
336        exit(1);
337    }
338}
339
340pub fn set_askpass_program(path: std::path::PathBuf) {
341    if ASKPASS_PROGRAM.set(path).is_err() {
342        debug_panic!("askpass program has already been set");
343    }
344}
345
346#[inline]
347#[cfg(not(target_os = "windows"))]
348fn generate_askpass_script(
349    shell_kind: ShellKind,
350    askpass_program: &std::path::Path,
351    askpass_socket: &std::path::Path,
352) -> Result<String> {
353    let askpass_program = shell_kind.prepend_command_prefix(
354        askpass_program
355            .to_str()
356            .context("Askpass program is on a non-utf8 path")?,
357    );
358    let askpass_program = shell_kind
359        .try_quote_prefix_aware(&askpass_program)
360        .context("Failed to shell-escape Askpass program path")?;
361    let askpass_socket = askpass_socket
362        .try_shell_safe(Some(&shell_kind))
363        .context("Failed to shell-escape Askpass socket path")?;
364    let print_args = "printf '%s\\0' \"$@\"";
365    let shebang = "#!/bin/sh";
366    Ok(format!(
367        "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
368    ))
369}
370
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes them into
/// `askpass_program --askpass=<socket>`, discarding that program's stderr. It sets
/// `$ErrorActionPreference` to `'Stop'`. Both paths are quoted for `shell_kind` first.
///
/// # Errors
/// Fails if `askpass_program` is not valid UTF-8 or if either path cannot be
/// shell-escaped.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(
    shell_kind: ShellKind,
    askpass_program: &std::path::Path,
    askpass_socket: &std::path::Path,
) -> Result<String> {
    let program_utf8 = askpass_program
        .to_str()
        .context("Askpass program is on a non-utf8 path")?;
    let program_with_prefix = shell_kind.prepend_command_prefix(program_utf8);
    let askpass_program = shell_kind
        .try_quote_prefix_aware(&program_with_prefix)
        .context("Failed to shell-escape Askpass program path")?;
    let askpass_socket = askpass_socket
        .try_shell_safe(Some(&shell_kind))
        .context("Failed to shell-escape Askpass socket path")?;

    let script = format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | {askpass_program} --askpass={askpass_socket} 2> $null
        "#,
    );
    Ok(script)
}