askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
  4
  5use net::async_net::UnixListener;
  6use smol::lock::Mutex;
  7use util::fs::make_file_executable;
  8
  9use std::ffi::OsStr;
 10use std::ops::ControlFlow;
 11use std::sync::Arc;
 12use std::sync::OnceLock;
 13use std::time::Duration;
 14
 15use anyhow::{Context as _, Result};
 16use futures::channel::{mpsc, oneshot};
 17use futures::{
 18    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 19    select_biased,
 20};
 21use gpui::{AsyncApp, BackgroundExecutor, Task};
 22use smol::fs;
 23use util::{ResultExt as _, debug_panic, maybe, paths::PathExt};
 24
/// Path to the program used for askpass
///
/// On Unix and remote servers, this defaults to the current executable
/// On Windows, this is set to the CLI variant of zed
///
/// Set explicitly through `set_askpass_program`; otherwise it is lazily initialized from
/// `std::env::current_exe()` the first time a `PasswordProxy` is created. The generated
/// askpass script pipes the prompt into this program with `--askpass=<socket>`.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
 30
 31#[derive(PartialEq, Eq)]
 32pub enum AskPassResult {
 33    CancelledByUser,
 34    Timedout,
 35}
 36
/// Forwards askpass password prompts to a UI callback supplied at construction time.
pub struct AskPassDelegate {
    // Queue of pending prompts, each paired with the one-shot channel the password is sent back on.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    // Executor used to spawn the per-request tasks returned by `ask_password`.
    executor: BackgroundExecutor,
    // Task that drains the prompt queue and invokes the prompt callback; kept alive with the delegate.
    _task: Task<()>,
}
 42
 43impl AskPassDelegate {
 44    pub fn new(
 45        cx: &mut AsyncApp,
 46        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 47        + Send
 48        + Sync
 49        + 'static,
 50    ) -> Self {
 51        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 52        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 53            while let Some((prompt, channel)) = rx.next().await {
 54                password_prompt(prompt, channel, cx);
 55            }
 56        });
 57        Self {
 58            tx,
 59            _task: task,
 60            executor: cx.background_executor().clone(),
 61        }
 62    }
 63
 64    pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
 65        let mut this_tx = self.tx.clone();
 66        self.executor.spawn(async move {
 67            let (tx, rx) = oneshot::channel();
 68            this_tx.send((prompt, tx)).await.ok()?;
 69            rx.await.ok()
 70        })
 71    }
 72}
 73
/// A live askpass endpoint: a generated askpass script plus the background listener answering it.
///
/// NOTE(review): `run` documents a dependency on drop order; check before reordering fields.
pub struct AskPassSession {
    // First password delivered through the askpass script (Windows only), read by `get_password`.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
    // Owns the socket-listening task and the generated script.
    askpass_task: PasswordProxy,
    // Fires when the first password request arrives; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    // Fires when a password request yields no password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
 81
/// File name of the generated askpass script inside the proxy's temporary directory (POSIX shell).
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
/// File name of the generated askpass script inside the proxy's temporary directory (PowerShell).
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 86
 87impl AskPassSession {
 88    /// This will create a new AskPassSession.
 89    /// You must retain this session until the master process exits.
 90    #[must_use]
 91    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 92        #[cfg(target_os = "windows")]
 93        let secret = std::sync::Arc::new(OnceLock::new());
 94        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 95
 96        let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
 97
 98        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 99        let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
100
101        #[cfg(target_os = "windows")]
102        let askpass_secret = secret.clone();
103        let get_password = {
104            let executor = executor.clone();
105
106            move |prompt| {
107                let prompt = delegate.ask_password(prompt);
108                let kill_tx = kill_tx.clone();
109                let askpass_opened_tx = askpass_opened_tx.clone();
110                #[cfg(target_os = "windows")]
111                let askpass_secret = askpass_secret.clone();
112                executor.spawn(async move {
113                    if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
114                        askpass_opened_tx.send(()).ok();
115                    }
116                    if let Some(password) = prompt.await {
117                        #[cfg(target_os = "windows")]
118                        {
119                            _ = askpass_secret.set(password.clone());
120                        }
121                        ControlFlow::Continue(Ok(password))
122                    } else {
123                        if let Some(kill_tx) = kill_tx.lock().await.take() {
124                            kill_tx.send(()).log_err();
125                        }
126                        ControlFlow::Break(())
127                    }
128                })
129            }
130        };
131        let askpass_task = PasswordProxy::new(get_password, executor.clone()).await?;
132
133        Ok(Self {
134            #[cfg(target_os = "windows")]
135            secret,
136
137            askpass_task,
138            askpass_kill_master_rx: Some(askpass_kill_master_rx),
139            askpass_opened_rx: Some(askpass_opened_rx),
140        })
141    }
142
143    // This will run the askpass task forever, resolving as many authentication requests as needed.
144    // The caller is responsible for examining the result of their own commands and cancelling this
145    // future when this is no longer needed. Note that this can only be called once, but due to the
146    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
147    pub async fn run(&mut self) -> AskPassResult {
148        // This is the default timeout setting used by VSCode.
149        let connection_timeout = Duration::from_secs(17);
150        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
151        let askpass_kill_master_rx = self
152            .askpass_kill_master_rx
153            .take()
154            .expect("Only call run once");
155
156        select_biased! {
157            _ = askpass_opened_rx.fuse() => {
158                // Note: this await can only resolve after we are dropped.
159                askpass_kill_master_rx.await.ok();
160                AskPassResult::CancelledByUser
161            }
162
163            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
164                AskPassResult::Timedout
165            }
166        }
167    }
168
169    /// This will return the password that was last set by the askpass script.
170    #[cfg(target_os = "windows")]
171    pub fn get_password(&self) -> Option<EncryptedPassword> {
172        self.secret.get().cloned()
173    }
174
175    pub fn script_path(&self) -> impl AsRef<OsStr> {
176        self.askpass_task.script_path()
177    }
178}
179
/// Serves password requests from a generated askpass script over a socket in a temporary directory.
pub struct PasswordProxy {
    // Background task accepting socket connections; it also owns the temporary directory.
    _task: Task<()>,
    // Path of the generated `askpass.sh` script.
    #[cfg(not(target_os = "windows"))]
    askpass_script_path: std::path::PathBuf,
    // `powershell.exe ... -File <script>` command line that runs the generated `askpass.ps1`.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
}
187
188impl PasswordProxy {
189    pub async fn new(
190        mut get_password: impl FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
191        + 'static
192        + Send
193        + Sync,
194        executor: BackgroundExecutor,
195    ) -> Result<Self> {
196        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
197        let askpass_socket = temp_dir.path().join("askpass.sock");
198        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
199        let current_exec =
200            std::env::current_exe().context("Failed to determine current zed executable path.")?;
201
202        let askpass_program = ASKPASS_PROGRAM
203            .get_or_init(|| current_exec)
204            .try_shell_safe()
205            .context("Failed to shell-escape Askpass program path.")?
206            .to_string();
207        // Create an askpass script that communicates back to this process.
208        let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
209        let _task = executor.spawn(async move {
210            maybe!(async move {
211                let listener =
212                    UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
213
214                while let Ok((mut stream, _)) = listener.accept().await {
215                    let mut buffer = Vec::new();
216                    let mut reader = BufReader::new(&mut stream);
217                    if reader.read_until(b'\0', &mut buffer).await.is_err() {
218                        buffer.clear();
219                    }
220                    let prompt = String::from_utf8_lossy(&buffer).into_owned();
221                    let password = get_password(prompt).await;
222                    match password {
223                        ControlFlow::Continue(password) => {
224                            if let Ok(password) = password
225                                && let Ok(decrypted) =
226                                    password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
227                            {
228                                stream.write_all(decrypted.as_bytes()).await.log_err();
229                            }
230                        }
231                        ControlFlow::Break(()) => {
232                            // note: we expect the caller to drop this task when it's done.
233                            // We need to keep the stream open until the caller is done to avoid
234                            // spurious errors from ssh.
235                            std::future::pending::<()>().await;
236                            drop(stream);
237                        }
238                    }
239                }
240                drop(temp_dir);
241                Result::<_, anyhow::Error>::Ok(())
242            })
243            .await
244            .log_err();
245        });
246
247        fs::write(&askpass_script_path, askpass_script)
248            .await
249            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
250        make_file_executable(&askpass_script_path).await?;
251        #[cfg(target_os = "windows")]
252        let askpass_helper = format!(
253            "powershell.exe -ExecutionPolicy Bypass -File {}",
254            askpass_script_path.display()
255        );
256
257        Ok(Self {
258            _task,
259            #[cfg(not(target_os = "windows"))]
260            askpass_script_path,
261            #[cfg(target_os = "windows")]
262            askpass_helper,
263        })
264    }
265
266    pub fn script_path(&self) -> impl AsRef<OsStr> {
267        #[cfg(not(target_os = "windows"))]
268        {
269            &self.askpass_script_path
270        }
271        #[cfg(target_os = "windows")]
272        {
273            &self.askpass_helper
274        }
275    }
276}
277/// The main function for when Zed is running in netcat mode for use in askpass.
278/// Called from both the remote server binary and the zed binary in their respective main functions.
279pub fn main(socket: &str) {
280    use net::UnixStream;
281    use std::io::{self, Read, Write};
282    use std::process::exit;
283
284    let mut stream = match UnixStream::connect(socket) {
285        Ok(stream) => stream,
286        Err(err) => {
287            eprintln!("Error connecting to socket {}: {}", socket, err);
288            exit(1);
289        }
290    };
291
292    let mut buffer = Vec::new();
293    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
294        eprintln!("Error reading from stdin: {}", err);
295        exit(1);
296    }
297
298    #[cfg(target_os = "windows")]
299    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
300        buffer.pop();
301    }
302    if buffer.last() != Some(&b'\0') {
303        buffer.push(b'\0');
304    }
305
306    if let Err(err) = stream.write_all(&buffer) {
307        eprintln!("Error writing to socket: {}", err);
308        exit(1);
309    }
310
311    let mut response = Vec::new();
312    if let Err(err) = stream.read_to_end(&mut response) {
313        eprintln!("Error reading from socket: {}", err);
314        exit(1);
315    }
316
317    if let Err(err) = io::stdout().write_all(&response) {
318        eprintln!("Error writing to stdout: {}", err);
319        exit(1);
320    }
321}
322
323pub fn set_askpass_program(path: std::path::PathBuf) {
324    if ASKPASS_PROGRAM.set(path).is_err() {
325        debug_panic!("askpass program has already been set");
326    }
327}
328
/// Builds the POSIX `sh` askpass script.
///
/// The script prints each argument NUL-terminated and pipes the result into
/// `askpass_program --askpass=<socket>`, discarding stderr. `askpass_program` must already be
/// shell-escaped. The socket path is single-quoted here, so paths containing spaces or shell
/// metacharacters (e.g. a custom `TMPDIR`) survive intact.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    // Inside POSIX single quotes every character is literal; an embedded `'` is written as
    // `'\''` (close quote, escaped quote, reopen quote).
    let askpass_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', "'\\''")
    );
    format!(
        "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
339
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL and pipes them into
/// `& "askpass_program" '--askpass=<socket>'`, discarding stderr. The argument is single-quoted
/// so that socket paths containing spaces (e.g. a temp dir under `C:\Users\First Last`) are
/// passed as one argument.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    // PowerShell single-quoted strings are literal; only `'` needs escaping, by doubling it.
    let askpass_socket = askpass_socket.display().to_string().replace('\'', "''");
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{askpass_program}" '--askpass={askpass_socket}' 2> $null
        "#,
    )
}