askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, ProcessExt};
  4
  5#[cfg(target_os = "windows")]
  6use std::sync::OnceLock;
  7use std::{ffi::OsStr, time::Duration};
  8
  9use anyhow::{Context as _, Result};
 10use futures::channel::{mpsc, oneshot};
 11use futures::{
 12    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 13    select_biased,
 14};
 15use gpui::{AsyncApp, BackgroundExecutor, Task};
 16use smol::fs;
 17use util::ResultExt as _;
 18
 19use crate::encrypted_password::decrypt;
 20
 21#[derive(PartialEq, Eq)]
 22pub enum AskPassResult {
 23    CancelledByUser,
 24    Timedout,
 25}
 26
/// Routes askpass password prompts to a caller-supplied prompt callback.
///
/// Each request is a prompt string paired with a oneshot sender on which the (encrypted)
/// password is delivered; a task created in `AskPassDelegate::new` drains the queue.
pub struct AskPassDelegate {
    /// Queue of pending prompts, each paired with the channel that receives the answer.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    /// Task that drains the prompt queue; stored here so it is owned by (and dropped with)
    /// the delegate.
    _task: Task<()>,
}
 31
 32impl AskPassDelegate {
 33    pub fn new(
 34        cx: &mut AsyncApp,
 35        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 36        + Send
 37        + Sync
 38        + 'static,
 39    ) -> Self {
 40        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 41        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 42            while let Some((prompt, channel)) = rx.next().await {
 43                password_prompt(prompt, channel, cx);
 44            }
 45        });
 46        Self { tx, _task: task }
 47    }
 48
 49    pub async fn ask_password(&mut self, prompt: String) -> Option<EncryptedPassword> {
 50        let (tx, rx) = oneshot::channel();
 51        self.tx.send((prompt, tx)).await.ok()?;
 52        rx.await.ok()
 53    }
 54}
 55
/// A temporary askpass endpoint (socket plus helper script) serving one master process.
///
/// Created by `AskPassSession::new`; must be kept alive until the master process exits.
pub struct AskPassSession {
    /// Path of the generated `askpass.sh` script.
    #[cfg(not(target_os = "windows"))]
    script_path: std::path::PathBuf,
    /// Command line (`powershell.exe ... -File <script>`) that runs the generated script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
    /// First password supplied through this session, shared with the listener task.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
    /// Listener task answering askpass requests; it also owns the temporary directory.
    _askpass_task: Task<()>,
    /// Signalled when the askpass script first connects; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Signalled when the delegate declines to answer a prompt; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
 67
 68#[cfg(not(target_os = "windows"))]
 69const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 70#[cfg(target_os = "windows")]
 71const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 72
 73impl AskPassSession {
 74    /// This will create a new AskPassSession.
 75    /// You must retain this session until the master process exits.
 76    #[must_use]
 77    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 78        use net::async_net::UnixListener;
 79        use util::fs::make_file_executable;
 80
 81        #[cfg(target_os = "windows")]
 82        let secret = std::sync::Arc::new(OnceLock::new());
 83        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 84        let askpass_socket = temp_dir.path().join("askpass.sock");
 85        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 86        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 87        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 88        let zed_cli_path =
 89            util::get_shell_safe_zed_cli_path().context("getting zed-cli path for askpass")?;
 90
 91        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 92        let mut kill_tx = Some(askpass_kill_master_tx);
 93
 94        #[cfg(target_os = "windows")]
 95        let askpass_secret = secret.clone();
 96        let askpass_task = executor.spawn(async move {
 97            let mut askpass_opened_tx = Some(askpass_opened_tx);
 98
 99            while let Ok((mut stream, _)) = listener.accept().await {
100                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
101                    askpass_opened_tx.send(()).ok();
102                }
103                let mut buffer = Vec::new();
104                let mut reader = BufReader::new(&mut stream);
105                if reader.read_until(b'\0', &mut buffer).await.is_err() {
106                    buffer.clear();
107                }
108                let prompt = String::from_utf8_lossy(&buffer);
109                if let Some(password) = delegate.ask_password(prompt.into_owned()).await {
110                    #[cfg(target_os = "windows")]
111                    {
112                        askpass_secret.get_or_init(|| password.clone());
113                    }
114                    if let Ok(decrypted) = decrypt(password) {
115                        stream.write_all(decrypted.as_bytes()).await.log_err();
116                    }
117                } else {
118                    if let Some(kill_tx) = kill_tx.take() {
119                        kill_tx.send(()).log_err();
120                    }
121                    // note: we expect the caller to drop this task when it's done.
122                    // We need to keep the stream open until the caller is done to avoid
123                    // spurious errors from ssh.
124                    std::future::pending::<()>().await;
125                    drop(stream);
126                }
127            }
128            drop(temp_dir)
129        });
130
131        // Create an askpass script that communicates back to this process.
132        let askpass_script = generate_askpass_script(&zed_cli_path, &askpass_socket);
133        fs::write(&askpass_script_path, askpass_script)
134            .await
135            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
136        make_file_executable(&askpass_script_path).await?;
137        #[cfg(target_os = "windows")]
138        let askpass_helper = format!(
139            "powershell.exe -ExecutionPolicy Bypass -File {}",
140            askpass_script_path.display()
141        );
142
143        Ok(Self {
144            #[cfg(not(target_os = "windows"))]
145            script_path: askpass_script_path,
146
147            #[cfg(target_os = "windows")]
148            secret,
149            #[cfg(target_os = "windows")]
150            askpass_helper,
151
152            _askpass_task: askpass_task,
153            askpass_kill_master_rx: Some(askpass_kill_master_rx),
154            askpass_opened_rx: Some(askpass_opened_rx),
155        })
156    }
157
158    #[cfg(not(target_os = "windows"))]
159    pub fn script_path(&self) -> impl AsRef<OsStr> {
160        &self.script_path
161    }
162
163    #[cfg(target_os = "windows")]
164    pub fn script_path(&self) -> impl AsRef<OsStr> {
165        &self.askpass_helper
166    }
167
168    // This will run the askpass task forever, resolving as many authentication requests as needed.
169    // The caller is responsible for examining the result of their own commands and cancelling this
170    // future when this is no longer needed. Note that this can only be called once, but due to the
171    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
172    pub async fn run(&mut self) -> AskPassResult {
173        // This is the default timeout setting used by VSCode.
174        let connection_timeout = Duration::from_secs(17);
175        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
176        let askpass_kill_master_rx = self
177            .askpass_kill_master_rx
178            .take()
179            .expect("Only call run once");
180
181        select_biased! {
182            _ = askpass_opened_rx.fuse() => {
183                // Note: this await can only resolve after we are dropped.
184                askpass_kill_master_rx.await.ok();
185                AskPassResult::CancelledByUser
186            }
187
188            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
189                AskPassResult::Timedout
190            }
191        }
192    }
193
194    /// This will return the password that was last set by the askpass script.
195    #[cfg(target_os = "windows")]
196    pub fn get_password(&self) -> Option<EncryptedPassword> {
197        self.secret.get().cloned()
198    }
199}
200
201/// The main function for when Zed is running in netcat mode for use in askpass.
202/// Called from both the remote server binary and the zed binary in their respective main functions.
203pub fn main(socket: &str) {
204    use net::UnixStream;
205    use std::io::{self, Read, Write};
206    use std::process::exit;
207
208    let mut stream = match UnixStream::connect(socket) {
209        Ok(stream) => stream,
210        Err(err) => {
211            eprintln!("Error connecting to socket {}: {}", socket, err);
212            exit(1);
213        }
214    };
215
216    let mut buffer = Vec::new();
217    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
218        eprintln!("Error reading from stdin: {}", err);
219        exit(1);
220    }
221
222    #[cfg(target_os = "windows")]
223    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
224        buffer.pop();
225    }
226    if buffer.last() != Some(&b'\0') {
227        buffer.push(b'\0');
228    }
229
230    if let Err(err) = stream.write_all(&buffer) {
231        eprintln!("Error writing to socket: {}", err);
232        exit(1);
233    }
234
235    let mut response = Vec::new();
236    if let Err(err) = stream.read_to_end(&mut response) {
237        eprintln!("Error reading from socket: {}", err);
238        exit(1);
239    }
240
241    if let Err(err) = io::stdout().write_all(&response) {
242        eprintln!("Error writing to stdout: {}", err);
243        exit(1);
244    }
245}
246
/// Builds the POSIX shell askpass script.
///
/// The script prints each argument terminated by a NUL byte and pipes the result into the Zed
/// CLI's `--askpass` mode, which is connected to `askpass_socket`.
///
/// The socket path is single-quoted so that a temporary directory containing spaces or shell
/// metacharacters (e.g. from a custom `TMPDIR`) is passed as one literal argument. An embedded
/// `'` is written as `'\''`: close the quote, emit an escaped quote, then reopen.
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    let quoted_socket = askpass_socket.display().to_string().replace('\'', r"'\''");
    format!(
        "{shebang}\n{print_args} | {zed_cli} --askpass='{askpass_socket}' 2> /dev/null \n",
        zed_cli = zed_cli_path,
        askpass_socket = quoted_socket,
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
258
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL bytes and pipes them into the Zed CLI's
/// `--askpass` mode, which is connected to `askpass_socket`.
///
/// The socket argument is a single-quoted (literal) PowerShell string, so a temp directory
/// under a profile path containing spaces (e.g. `C:\Users\Jane Doe`) or `$` is neither split
/// nor expanded. Embedded `'` characters are doubled, which is PowerShell's escape inside
/// single-quoted strings.
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    let quoted_socket = askpass_socket.display().to_string().replace('\'', "''");
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{zed_cli}" '--askpass={askpass_socket}' 2> $null
        "#,
        zed_cli = zed_cli_path,
        askpass_socket = quoted_socket,
    )
}