askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, ProcessExt};
  4use util::paths::PathExt;
  5
  6use std::sync::OnceLock;
  7use std::{ffi::OsStr, time::Duration};
  8
  9use anyhow::{Context as _, Result};
 10use futures::channel::{mpsc, oneshot};
 11use futures::{
 12    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 13    select_biased,
 14};
 15use gpui::{AsyncApp, BackgroundExecutor, Task};
 16use smol::fs;
 17use util::{ResultExt as _, debug_panic};
 18
 19use crate::encrypted_password::decrypt;
 20
 /// Path to the program used for askpass.
///
/// On Unix and remote servers, this defaults to the current executable: the first
/// [`AskPassSession::new`] call fills it in if nothing was registered before. On Windows, this is
/// set to the CLI variant of zed. Use [`set_askpass_program`] to register a program up front.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();

/// How [`AskPassSession::run`] finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// The askpass script connected, and the askpass task later stopped answering it (e.g. the
    /// delegate did not provide a password).
    CancelledByUser,
    /// The askpass script did not connect before the connection timeout elapsed.
    Timedout,
}
 32
 33pub struct AskPassDelegate {
 34    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
 35    _task: Task<()>,
 36}
 37
 38impl AskPassDelegate {
 39    pub fn new(
 40        cx: &mut AsyncApp,
 41        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 42        + Send
 43        + Sync
 44        + 'static,
 45    ) -> Self {
 46        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 47        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 48            while let Some((prompt, channel)) = rx.next().await {
 49                password_prompt(prompt, channel, cx);
 50            }
 51        });
 52        Self { tx, _task: task }
 53    }
 54
 55    pub async fn ask_password(&mut self, prompt: String) -> Result<EncryptedPassword> {
 56        let (tx, rx) = oneshot::channel();
 57        self.tx.send((prompt, tx)).await?;
 58        Ok(rx.await?)
 59    }
 60}
 61
 62pub struct AskPassSession {
 63    #[cfg(not(target_os = "windows"))]
 64    script_path: std::path::PathBuf,
 65    #[cfg(target_os = "windows")]
 66    askpass_helper: String,
 67    #[cfg(target_os = "windows")]
 68    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
 69    _askpass_task: Task<()>,
 70    askpass_opened_rx: Option<oneshot::Receiver<()>>,
 71    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 72}
 73
 74#[cfg(not(target_os = "windows"))]
 75const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 76#[cfg(target_os = "windows")]
 77const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 78
 79impl AskPassSession {
 80    /// This will create a new AskPassSession.
 81    /// You must retain this session until the master process exits.
 82    #[must_use]
 83    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 84        use net::async_net::UnixListener;
 85        use util::fs::make_file_executable;
 86
 87        #[cfg(target_os = "windows")]
 88        let secret = std::sync::Arc::new(OnceLock::new());
 89        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 90        let askpass_socket = temp_dir.path().join("askpass.sock");
 91        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 92        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 93        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 94
 95        let current_exec =
 96            std::env::current_exe().context("Failed to determine current zed executable path.")?;
 97
 98        let askpass_program = ASKPASS_PROGRAM
 99            .get_or_init(|| current_exec)
100            .try_shell_safe()
101            .context("Failed to shell-escape Askpass program path.")?
102            .to_string();
103
104        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
105        let mut kill_tx = Some(askpass_kill_master_tx);
106
107        #[cfg(target_os = "windows")]
108        let askpass_secret = secret.clone();
109        let askpass_task = executor.spawn(async move {
110            let mut askpass_opened_tx = Some(askpass_opened_tx);
111
112            while let Ok((mut stream, _)) = listener.accept().await {
113                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
114                    askpass_opened_tx.send(()).ok();
115                }
116                let mut buffer = Vec::new();
117                let mut reader = BufReader::new(&mut stream);
118                if reader.read_until(b'\0', &mut buffer).await.is_err() {
119                    buffer.clear();
120                }
121                let prompt = String::from_utf8_lossy(&buffer);
122                if let Some(password) = delegate
123                    .ask_password(prompt.to_string())
124                    .await
125                    .context("getting askpass password")
126                    .log_err()
127                {
128                    #[cfg(target_os = "windows")]
129                    {
130                        askpass_secret.get_or_init(|| password.clone());
131                    }
132                    if let Ok(decrypted) = decrypt(password) {
133                        stream.write_all(decrypted.as_bytes()).await.log_err();
134                    }
135                } else {
136                    if let Some(kill_tx) = kill_tx.take() {
137                        kill_tx.send(()).log_err();
138                    }
139                    // note: we expect the caller to drop this task when it's done.
140                    // We need to keep the stream open until the caller is done to avoid
141                    // spurious errors from ssh.
142                    std::future::pending::<()>().await;
143                    drop(stream);
144                }
145            }
146            drop(temp_dir)
147        });
148
149        // Create an askpass script that communicates back to this process.
150        let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
151        fs::write(&askpass_script_path, askpass_script)
152            .await
153            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
154        make_file_executable(&askpass_script_path).await?;
155        #[cfg(target_os = "windows")]
156        let askpass_helper = format!(
157            "powershell.exe -ExecutionPolicy Bypass -File {}",
158            askpass_script_path.display()
159        );
160
161        Ok(Self {
162            #[cfg(not(target_os = "windows"))]
163            script_path: askpass_script_path,
164
165            #[cfg(target_os = "windows")]
166            secret,
167            #[cfg(target_os = "windows")]
168            askpass_helper,
169
170            _askpass_task: askpass_task,
171            askpass_kill_master_rx: Some(askpass_kill_master_rx),
172            askpass_opened_rx: Some(askpass_opened_rx),
173        })
174    }
175
176    #[cfg(not(target_os = "windows"))]
177    pub fn script_path(&self) -> impl AsRef<OsStr> {
178        &self.script_path
179    }
180
181    #[cfg(target_os = "windows")]
182    pub fn script_path(&self) -> impl AsRef<OsStr> {
183        &self.askpass_helper
184    }
185
186    // This will run the askpass task forever, resolving as many authentication requests as needed.
187    // The caller is responsible for examining the result of their own commands and cancelling this
188    // future when this is no longer needed. Note that this can only be called once, but due to the
189    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
190    pub async fn run(&mut self) -> AskPassResult {
191        // This is the default timeout setting used by VSCode.
192        let connection_timeout = Duration::from_secs(17);
193        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
194        let askpass_kill_master_rx = self
195            .askpass_kill_master_rx
196            .take()
197            .expect("Only call run once");
198
199        select_biased! {
200            _ = askpass_opened_rx.fuse() => {
201                // Note: this await can only resolve after we are dropped.
202                askpass_kill_master_rx.await.ok();
203                AskPassResult::CancelledByUser
204            }
205
206            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
207                AskPassResult::Timedout
208            }
209        }
210    }
211
212    /// This will return the password that was last set by the askpass script.
213    #[cfg(target_os = "windows")]
214    pub fn get_password(&self) -> Option<EncryptedPassword> {
215        self.secret.get().cloned()
216    }
217}
218
219/// The main function for when Zed is running in netcat mode for use in askpass.
220/// Called from both the remote server binary and the zed binary in their respective main functions.
221pub fn main(socket: &str) {
222    use net::UnixStream;
223    use std::io::{self, Read, Write};
224    use std::process::exit;
225
226    let mut stream = match UnixStream::connect(socket) {
227        Ok(stream) => stream,
228        Err(err) => {
229            eprintln!("Error connecting to socket {}: {}", socket, err);
230            exit(1);
231        }
232    };
233
234    let mut buffer = Vec::new();
235    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
236        eprintln!("Error reading from stdin: {}", err);
237        exit(1);
238    }
239
240    #[cfg(target_os = "windows")]
241    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
242        buffer.pop();
243    }
244    if buffer.last() != Some(&b'\0') {
245        buffer.push(b'\0');
246    }
247
248    if let Err(err) = stream.write_all(&buffer) {
249        eprintln!("Error writing to socket: {}", err);
250        exit(1);
251    }
252
253    let mut response = Vec::new();
254    if let Err(err) = stream.read_to_end(&mut response) {
255        eprintln!("Error reading from socket: {}", err);
256        exit(1);
257    }
258
259    if let Err(err) = io::stdout().write_all(&response) {
260        eprintln!("Error writing to stdout: {}", err);
261        exit(1);
262    }
263}
264
265pub fn set_askpass_program(path: std::path::PathBuf) {
266    if ASKPASS_PROGRAM.set(path).is_err() {
267        debug_panic!("askpass program has already been set");
268    }
269}
270
/// Builds the POSIX `sh` askpass script.
///
/// The script writes each of its arguments NUL-terminated into `askpass_program`, which runs
/// with `--askpass=<socket>`. It discards the program's stderr.
///
/// `askpass_program` must already be shell-escaped by the caller. The socket path is quoted
/// here so that temp directories containing spaces or quotes still produce a valid script.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    /// Wraps `value` in POSIX single quotes, escaping embedded single quotes as `'\''`.
    fn single_quote(value: &str) -> String {
        format!("'{}'", value.replace('\'', r"'\''"))
    }

    format!(
        "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
        askpass_socket = single_quote(&askpass_socket.to_string_lossy()),
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
281
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes them into `askpass_program`,
/// which runs with `--askpass=<socket>`. It discards the program's stderr.
///
/// The socket argument is passed as one single-quoted (literal) PowerShell string. This keeps
/// paths containing spaces or `$` intact; inside such strings, `'` is escaped by doubling it.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    let askpass_arg = format!(
        "'--askpass={}'",
        askpass_socket.to_string_lossy().replace('\'', "''")
    );
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{askpass_program}" {askpass_arg} 2> $null
        "#,
    )
}