askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
  4
  5use net::async_net::UnixListener;
  6use smol::lock::Mutex;
  7use util::fs::make_file_executable;
  8
  9use std::ffi::OsStr;
 10use std::ops::ControlFlow;
 11use std::sync::Arc;
 12use std::sync::OnceLock;
 13use std::time::Duration;
 14
 15use anyhow::{Context as _, Result};
 16use futures::channel::{mpsc, oneshot};
 17use futures::{
 18    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 19    select_biased,
 20};
 21use gpui::{AsyncApp, BackgroundExecutor, Task};
 22use smol::fs;
 23use util::{ResultExt as _, debug_panic, maybe, paths::PathExt, shell::ShellKind};
 24
 /// Location of the executable that generated askpass scripts invoke.
 ///
 /// When nothing has been installed through [`set_askpass_program`], the first
 /// [`PasswordProxy::new`] call fills this with the path of the running
 /// executable. As noted by the original authors: on Unix and remote servers
 /// this defaults to the current executable, and on Windows it is set to the
 /// CLI variant of zed.
 static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
 30
 /// Reason why [`AskPassSession::run`] stopped waiting.
 ///
 /// Derives `Debug`, `Clone` and `Copy` in addition to the equality traits so
 /// the value can be logged, used with `assert_eq!`, and passed around freely.
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AskPassResult {
     /// An askpass request was opened and the session was subsequently told to
     /// stop because no password was supplied for a prompt.
     CancelledByUser,
     /// No askpass request was opened before the connection timeout elapsed.
     Timedout,
 }
 36
 37pub struct AskPassDelegate {
 38    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
 39    executor: BackgroundExecutor,
 40    _task: Task<()>,
 41}
 42
 43impl AskPassDelegate {
 44    pub fn new(
 45        cx: &mut AsyncApp,
 46        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 47        + Send
 48        + Sync
 49        + 'static,
 50    ) -> Self {
 51        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 52        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 53            while let Some((prompt, channel)) = rx.next().await {
 54                password_prompt(prompt, channel, cx);
 55            }
 56        });
 57        Self {
 58            tx,
 59            _task: task,
 60            executor: cx.background_executor().clone(),
 61        }
 62    }
 63
 64    pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
 65        let mut this_tx = self.tx.clone();
 66        self.executor.spawn(async move {
 67            let (tx, rx) = oneshot::channel();
 68            this_tx.send((prompt, tx)).await.ok()?;
 69            rx.await.ok()
 70        })
 71    }
 72}
 73
 74pub struct AskPassSession {
 75    #[cfg(target_os = "windows")]
 76    secret: std::sync::Arc<std::sync::Mutex<Option<EncryptedPassword>>>,
 77    askpass_task: PasswordProxy,
 78    askpass_opened_rx: Option<oneshot::Receiver<()>>,
 79    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 80    executor: BackgroundExecutor,
 81}
 82
 /// File name of the generated askpass script (PowerShell flavour).
 #[cfg(target_os = "windows")]
 const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";

 /// File name of the generated askpass script (POSIX shell flavour).
 #[cfg(not(target_os = "windows"))]
 const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 88
 89impl AskPassSession {
 90    /// This will create a new AskPassSession.
 91    /// You must retain this session until the master process exits.
 92    #[must_use]
 93    pub async fn new(executor: BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 94        #[cfg(target_os = "windows")]
 95        let secret = std::sync::Arc::new(std::sync::Mutex::new(None));
 96
 97        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 98
 99        let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
100
101        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
102        let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
103
104        let get_password = {
105            let executor = executor.clone();
106
107            #[cfg(target_os = "windows")]
108            let askpass_secret = secret.clone();
109            move |prompt| {
110                let prompt = delegate.ask_password(prompt);
111                let kill_tx = kill_tx.clone();
112                let askpass_opened_tx = askpass_opened_tx.clone();
113                #[cfg(target_os = "windows")]
114                let askpass_secret = askpass_secret.clone();
115                executor.spawn(async move {
116                    if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
117                        askpass_opened_tx.send(()).ok();
118                    }
119                    if let Some(password) = prompt.await {
120                        #[cfg(target_os = "windows")]
121                        {
122                            askpass_secret.lock().unwrap().replace(password.clone());
123                        }
124                        ControlFlow::Continue(Ok(password))
125                    } else {
126                        if let Some(kill_tx) = kill_tx.lock().await.take() {
127                            kill_tx.send(()).log_err();
128                        }
129                        ControlFlow::Break(())
130                    }
131                })
132            }
133        };
134        let askpass_task = PasswordProxy::new(Box::new(get_password), executor.clone()).await?;
135
136        Ok(Self {
137            #[cfg(target_os = "windows")]
138            secret,
139
140            askpass_task,
141            askpass_kill_master_rx: Some(askpass_kill_master_rx),
142            askpass_opened_rx: Some(askpass_opened_rx),
143            executor,
144        })
145    }
146
147    // This will run the askpass task forever, resolving as many authentication requests as needed.
148    // The caller is responsible for examining the result of their own commands and cancelling this
149    // future when this is no longer needed. Note that this can only be called once, but due to the
150    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
151    pub async fn run(&mut self) -> AskPassResult {
152        // This is the default timeout setting used by VSCode.
153        let connection_timeout = Duration::from_secs(17);
154        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
155        let askpass_kill_master_rx = self
156            .askpass_kill_master_rx
157            .take()
158            .expect("Only call run once");
159        let executor = self.executor.clone();
160
161        select_biased! {
162            _ = askpass_opened_rx.fuse() => {
163                // Note: this await can only resolve after we are dropped.
164                askpass_kill_master_rx.await.ok();
165                AskPassResult::CancelledByUser
166            }
167
168            _ = futures::FutureExt::fuse(executor.timer(connection_timeout)) => {
169                AskPassResult::Timedout
170            }
171        }
172    }
173
174    /// This will return the password that was last set by the askpass script.
175    #[cfg(target_os = "windows")]
176    pub fn get_password(&self) -> Option<EncryptedPassword> {
177        self.secret.lock().ok()?.clone()
178    }
179
180    pub fn script_path(&self) -> impl AsRef<OsStr> {
181        self.askpass_task.script_path()
182    }
183}
184
185pub struct PasswordProxy {
186    _task: Task<()>,
187    #[cfg(not(target_os = "windows"))]
188    askpass_script_path: std::path::PathBuf,
189    #[cfg(target_os = "windows")]
190    askpass_helper: String,
191}
192
193impl PasswordProxy {
194    pub async fn new(
195        mut get_password: Box<
196            dyn FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
197                + 'static
198                + Send
199                + Sync,
200        >,
201        executor: BackgroundExecutor,
202    ) -> Result<Self> {
203        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
204        let askpass_socket = temp_dir.path().join("askpass.sock");
205        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
206        let current_exec =
207            std::env::current_exe().context("Failed to determine current zed executable path.")?;
208
209        // TODO: inferred from the use of powershell.exe in askpass_helper_script
210        let shell_kind = if cfg!(windows) {
211            ShellKind::PowerShell
212        } else {
213            ShellKind::Posix
214        };
215        let askpass_program = ASKPASS_PROGRAM.get_or_init(|| current_exec);
216        // Create an askpass script that communicates back to this process.
217        let askpass_script = generate_askpass_script(shell_kind, askpass_program, &askpass_socket)?;
218        let _task = executor.spawn(async move {
219            maybe!(async move {
220                let listener =
221                    UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
222
223                while let Ok((mut stream, _)) = listener.accept().await {
224                    let mut buffer = Vec::new();
225                    let mut reader = BufReader::new(&mut stream);
226                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
227                        buffer.clear();
228                    }
229                    let prompt = String::from_utf8_lossy(&buffer).into_owned();
230                    let password = get_password(prompt).await;
231                    match password {
232                        ControlFlow::Continue(password) => {
233                            if let Ok(password) = password
234                                && let Ok(decrypted) =
235                                    password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
236                            {
237                                stream.write_all(decrypted.as_bytes()).await.log_err();
238                            }
239                        }
240                        ControlFlow::Break(()) => {
241                            // note: we expect the caller to drop this task when it's done.
242                            // We need to keep the stream open until the caller is done to avoid
243                            // spurious errors from ssh.
244                            std::future::pending::<()>().await;
245                            drop(stream);
246                        }
247                    }
248                }
249                drop(temp_dir);
250                Result::<_, anyhow::Error>::Ok(())
251            })
252            .await
253            .log_err();
254        });
255
256        fs::write(&askpass_script_path, askpass_script)
257            .await
258            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
259        make_file_executable(&askpass_script_path)
260            .await
261            .with_context(|| {
262                format!("marking askpass script executable at {askpass_script_path:?}")
263            })?;
264        // todo(shell): There might be no powershell on the system
265        #[cfg(target_os = "windows")]
266        let askpass_helper = format!(
267            "powershell.exe -ExecutionPolicy Bypass -File \"{}\"",
268            askpass_script_path.display()
269        );
270
271        Ok(Self {
272            _task,
273            #[cfg(not(target_os = "windows"))]
274            askpass_script_path,
275            #[cfg(target_os = "windows")]
276            askpass_helper,
277        })
278    }
279
280    pub fn script_path(&self) -> impl AsRef<OsStr> {
281        #[cfg(not(target_os = "windows"))]
282        {
283            &self.askpass_script_path
284        }
285        #[cfg(target_os = "windows")]
286        {
287            &self.askpass_helper
288        }
289    }
290}
291/// The main function for when Zed is running in netcat mode for use in askpass.
292/// Called from both the remote server binary and the zed binary in their respective main functions.
293pub fn main(socket: &str) {
294    use net::UnixStream;
295    use std::io::{self, Read, Write};
296    use std::process::exit;
297
298    let mut stream = match UnixStream::connect(socket) {
299        Ok(stream) => stream,
300        Err(err) => {
301            eprintln!("Error connecting to socket {}: {}", socket, err);
302            exit(1);
303        }
304    };
305
306    let mut buffer = Vec::new();
307    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
308        eprintln!("Error reading from stdin: {}", err);
309        exit(1);
310    }
311
312    #[cfg(target_os = "windows")]
313    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
314        buffer.pop();
315    }
316    if buffer.last() != Some(&b'\0') {
317        buffer.push(b'\0');
318    }
319
320    if let Err(err) = stream.write_all(&buffer) {
321        eprintln!("Error writing to socket: {}", err);
322        exit(1);
323    }
324
325    let mut response = Vec::new();
326    if let Err(err) = stream.read_to_end(&mut response) {
327        eprintln!("Error reading from socket: {}", err);
328        exit(1);
329    }
330
331    if let Err(err) = io::stdout().write_all(&response) {
332        eprintln!("Error writing to stdout: {}", err);
333        exit(1);
334    }
335}
336
337pub fn set_askpass_program(path: std::path::PathBuf) {
338    if ASKPASS_PROGRAM.set(path).is_err() {
339        debug_panic!("askpass program has already been set");
340    }
341}
342
343#[inline]
344#[cfg(not(target_os = "windows"))]
345fn generate_askpass_script(
346    shell_kind: ShellKind,
347    askpass_program: &std::path::Path,
348    askpass_socket: &std::path::Path,
349) -> Result<String> {
350    let askpass_program = shell_kind.prepend_command_prefix(
351        askpass_program
352            .to_str()
353            .context("Askpass program is on a non-utf8 path")?,
354    );
355    let askpass_program = shell_kind
356        .try_quote_prefix_aware(&askpass_program)
357        .context("Failed to shell-escape Askpass program path")?;
358    let askpass_socket = askpass_socket
359        .try_shell_safe(shell_kind)
360        .context("Failed to shell-escape Askpass socket path")?;
361    let print_args = "printf '%s\\0' \"$@\"";
362    let shebang = "#!/bin/sh";
363    Ok(format!(
364        "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
365    ))
366}
367
/// Renders the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result
/// into `askpass_program --askpass=<askpass_socket>`, discarding stderr.
/// Both paths are shell-escaped for `shell_kind`.
///
/// # Errors
///
/// Fails when the program path is not valid UTF-8 or when either path cannot
/// be shell-escaped.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(
    shell_kind: ShellKind,
    askpass_program: &std::path::Path,
    askpass_socket: &std::path::Path,
) -> Result<String> {
    let program_path = askpass_program
        .to_str()
        .context("Askpass program is on a non-utf8 path")?;
    let prefixed_program = shell_kind.prepend_command_prefix(program_path);
    let askpass_program = shell_kind
        .try_quote_prefix_aware(&prefixed_program)
        .context("Failed to shell-escape Askpass program path")?;
    let askpass_socket = askpass_socket
        .try_shell_safe(shell_kind)
        .context("Failed to shell-escape Askpass socket path")?;

    let script = format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | {askpass_program} --askpass={askpass_socket} 2> $null
        "#,
    );
    Ok(script)
}