askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, ProcessExt};
  4
  5#[cfg(target_os = "windows")]
  6use std::sync::OnceLock;
  7use std::{ffi::OsStr, time::Duration};
  8
  9use anyhow::{Context as _, Result};
 10use futures::channel::{mpsc, oneshot};
 11use futures::{
 12    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 13    select_biased,
 14};
 15use gpui::{AsyncApp, BackgroundExecutor, Task};
 16use smol::fs;
 17use util::ResultExt as _;
 18
 19use crate::encrypted_password::decrypt;
 20
 21#[derive(PartialEq, Eq)]
 22pub enum AskPassResult {
 23    CancelledByUser,
 24    Timedout,
 25}
 26
 27pub struct AskPassDelegate {
 28    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
 29    _task: Task<()>,
 30}
 31
 32impl AskPassDelegate {
 33    pub fn new(
 34        cx: &mut AsyncApp,
 35        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 36        + Send
 37        + Sync
 38        + 'static,
 39    ) -> Self {
 40        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 41        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 42            while let Some((prompt, channel)) = rx.next().await {
 43                password_prompt(prompt, channel, cx);
 44            }
 45        });
 46        Self { tx, _task: task }
 47    }
 48
 49    pub async fn ask_password(&mut self, prompt: String) -> Result<EncryptedPassword> {
 50        let (tx, rx) = oneshot::channel();
 51        self.tx.send((prompt, tx)).await?;
 52        Ok(rx.await?)
 53    }
 54}
 55
 56pub struct AskPassSession {
 57    #[cfg(not(target_os = "windows"))]
 58    script_path: std::path::PathBuf,
 59    #[cfg(target_os = "windows")]
 60    askpass_helper: String,
 61    #[cfg(target_os = "windows")]
 62    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
 63    _askpass_task: Task<()>,
 64    askpass_opened_rx: Option<oneshot::Receiver<()>>,
 65    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 66}
 67
 68#[cfg(not(target_os = "windows"))]
 69const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 70#[cfg(target_os = "windows")]
 71const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 72
 73impl AskPassSession {
 74    /// This will create a new AskPassSession.
 75    /// You must retain this session until the master process exits.
 76    #[must_use]
 77    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 78        use net::async_net::UnixListener;
 79        use util::fs::make_file_executable;
 80
 81        #[cfg(target_os = "windows")]
 82        let secret = std::sync::Arc::new(OnceLock::new());
 83        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 84        let askpass_socket = temp_dir.path().join("askpass.sock");
 85        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 86        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 87        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 88        let zed_cli_path =
 89            util::get_shell_safe_zed_cli_path().context("getting zed-cli path for askpass")?;
 90
 91        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 92        let mut kill_tx = Some(askpass_kill_master_tx);
 93
 94        #[cfg(target_os = "windows")]
 95        let askpass_secret = secret.clone();
 96        let askpass_task = executor.spawn(async move {
 97            let mut askpass_opened_tx = Some(askpass_opened_tx);
 98
 99            while let Ok((mut stream, _)) = listener.accept().await {
100                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
101                    askpass_opened_tx.send(()).ok();
102                }
103                let mut buffer = Vec::new();
104                let mut reader = BufReader::new(&mut stream);
105                if reader.read_until(b'\0', &mut buffer).await.is_err() {
106                    buffer.clear();
107                }
108                let prompt = String::from_utf8_lossy(&buffer);
109                if let Some(password) = delegate
110                    .ask_password(prompt.to_string())
111                    .await
112                    .context("getting askpass password")
113                    .log_err()
114                {
115                    #[cfg(target_os = "windows")]
116                    {
117                        askpass_secret.get_or_init(|| password.clone());
118                    }
119                    if let Ok(decrypted) = decrypt(password) {
120                        stream.write_all(decrypted.as_bytes()).await.log_err();
121                    }
122                } else {
123                    if let Some(kill_tx) = kill_tx.take() {
124                        kill_tx.send(()).log_err();
125                    }
126                    // note: we expect the caller to drop this task when it's done.
127                    // We need to keep the stream open until the caller is done to avoid
128                    // spurious errors from ssh.
129                    std::future::pending::<()>().await;
130                    drop(stream);
131                }
132            }
133            drop(temp_dir)
134        });
135
136        // Create an askpass script that communicates back to this process.
137        let askpass_script = generate_askpass_script(&zed_cli_path, &askpass_socket);
138        fs::write(&askpass_script_path, askpass_script)
139            .await
140            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
141        make_file_executable(&askpass_script_path).await?;
142        #[cfg(target_os = "windows")]
143        let askpass_helper = format!(
144            "powershell.exe -ExecutionPolicy Bypass -File {}",
145            askpass_script_path.display()
146        );
147
148        Ok(Self {
149            #[cfg(not(target_os = "windows"))]
150            script_path: askpass_script_path,
151
152            #[cfg(target_os = "windows")]
153            secret,
154            #[cfg(target_os = "windows")]
155            askpass_helper,
156
157            _askpass_task: askpass_task,
158            askpass_kill_master_rx: Some(askpass_kill_master_rx),
159            askpass_opened_rx: Some(askpass_opened_rx),
160        })
161    }
162
163    #[cfg(not(target_os = "windows"))]
164    pub fn script_path(&self) -> impl AsRef<OsStr> {
165        &self.script_path
166    }
167
168    #[cfg(target_os = "windows")]
169    pub fn script_path(&self) -> impl AsRef<OsStr> {
170        &self.askpass_helper
171    }
172
173    // This will run the askpass task forever, resolving as many authentication requests as needed.
174    // The caller is responsible for examining the result of their own commands and cancelling this
175    // future when this is no longer needed. Note that this can only be called once, but due to the
176    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
177    pub async fn run(&mut self) -> AskPassResult {
178        // This is the default timeout setting used by VSCode.
179        let connection_timeout = Duration::from_secs(17);
180        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
181        let askpass_kill_master_rx = self
182            .askpass_kill_master_rx
183            .take()
184            .expect("Only call run once");
185
186        select_biased! {
187            _ = askpass_opened_rx.fuse() => {
188                // Note: this await can only resolve after we are dropped.
189                askpass_kill_master_rx.await.ok();
190                AskPassResult::CancelledByUser
191            }
192
193            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
194                AskPassResult::Timedout
195            }
196        }
197    }
198
199    /// This will return the password that was last set by the askpass script.
200    #[cfg(target_os = "windows")]
201    pub fn get_password(&self) -> Option<EncryptedPassword> {
202        self.secret.get().cloned()
203    }
204}
205
206/// The main function for when Zed is running in netcat mode for use in askpass.
207/// Called from both the remote server binary and the zed binary in their respective main functions.
208pub fn main(socket: &str) {
209    use net::UnixStream;
210    use std::io::{self, Read, Write};
211    use std::process::exit;
212
213    let mut stream = match UnixStream::connect(socket) {
214        Ok(stream) => stream,
215        Err(err) => {
216            eprintln!("Error connecting to socket {}: {}", socket, err);
217            exit(1);
218        }
219    };
220
221    let mut buffer = Vec::new();
222    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
223        eprintln!("Error reading from stdin: {}", err);
224        exit(1);
225    }
226
227    #[cfg(target_os = "windows")]
228    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
229        buffer.pop();
230    }
231    if buffer.last() != Some(&b'\0') {
232        buffer.push(b'\0');
233    }
234
235    if let Err(err) = stream.write_all(&buffer) {
236        eprintln!("Error writing to socket: {}", err);
237        exit(1);
238    }
239
240    let mut response = Vec::new();
241    if let Err(err) = stream.read_to_end(&mut response) {
242        eprintln!("Error reading from socket: {}", err);
243        exit(1);
244    }
245
246    if let Err(err) = io::stdout().write_all(&response) {
247        eprintln!("Error writing to stdout: {}", err);
248        exit(1);
249    }
250}
251
/// Builds the POSIX `sh` askpass script.
///
/// The script writes each of its arguments followed by a NUL byte into zed-cli running in
/// `--askpass` mode, which forwards them to `askpass_socket`; stderr is discarded.
/// `zed_cli_path` is expected to be shell-safe already. The socket path is single-quoted here,
/// so temp directories containing spaces or shell metacharacters still work.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    /// Wraps `s` in single quotes, escaping embedded single quotes as `'\''`.
    fn shell_single_quote(s: &str) -> String {
        format!("'{}'", s.replace('\'', r"'\''"))
    }

    format!(
        "{shebang}\n{print_args} | {zed_cli} --askpass={askpass_socket} 2> /dev/null \n",
        zed_cli = zed_cli_path,
        askpass_socket = shell_single_quote(&askpass_socket.display().to_string()),
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
263
/// Builds the PowerShell askpass script used on Windows.
///
/// The script joins its arguments with NUL characters and pipes them into zed-cli running in
/// `--askpass` mode, which forwards them to `askpass_socket`; stderr is discarded.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    let socket = askpass_socket.display();
    format!(
        "\n        $ErrorActionPreference = 'Stop';\n        ($args -join [char]0) | & \"{zed_cli_path}\" --askpass={socket} 2> $null\n        "
    )
}