askpass.rs

  1mod encrypted_password;
  2
  3pub use encrypted_password::{EncryptedPassword, ProcessExt};
  4
  5#[cfg(target_os = "windows")]
  6use std::sync::OnceLock;
  7use std::{ffi::OsStr, time::Duration};
  8
  9use anyhow::{Context as _, Result};
 10use futures::channel::{mpsc, oneshot};
 11use futures::{
 12    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
 13    select_biased,
 14};
 15use gpui::{AsyncApp, BackgroundExecutor, Task};
 16use smol::fs;
 17use util::ResultExt as _;
 18
 19use crate::encrypted_password::decrypt;
 20
 21#[derive(PartialEq, Eq)]
 22pub enum AskPassResult {
 23    CancelledByUser,
 24    Timedout,
 25}
 26
/// Forwards password prompts to a caller-supplied callback.
///
/// Each prompt is queued on `tx` together with a one-shot reply channel; the
/// task stored in `_task` pulls the queue and hands every pair to the callback.
pub struct AskPassDelegate {
    /// Queue of `(prompt text, reply channel)` pairs consumed by `_task`.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    /// The task draining the queue; stored so the delegate owns it.
    _task: Task<()>,
}
 31
 32impl AskPassDelegate {
 33    pub fn new(
 34        cx: &mut AsyncApp,
 35        password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
 36        + Send
 37        + Sync
 38        + 'static,
 39    ) -> Self {
 40        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
 41        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 42            while let Some((prompt, channel)) = rx.next().await {
 43                password_prompt(prompt, channel, cx);
 44            }
 45        });
 46        Self { tx, _task: task }
 47    }
 48
 49    pub async fn ask_password(&mut self, prompt: String) -> Result<EncryptedPassword> {
 50        let (tx, rx) = oneshot::channel();
 51        self.tx.send((prompt, tx)).await?;
 52        Ok(rx.await?)
 53    }
 54}
 55
/// State for one askpass session: the generated helper script plus the
/// background task that answers the helper's socket connections.
pub struct AskPassSession {
    /// Location of the generated askpass script inside the session's temp dir.
    #[cfg(not(target_os = "windows"))]
    script_path: std::path::PathBuf,
    /// PowerShell command line that runs the generated askpass script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
    /// The first password the delegate produced (set at most once).
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
    /// Background task accepting connections on the askpass socket.
    _askpass_task: Task<()>,
    /// Fires when the first helper connects; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Fires when the delegate fails to supply a password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
 67
/// File name of the generated askpass helper script (POSIX shell script).
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
/// File name of the generated askpass helper script (PowerShell script).
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 72
 73impl AskPassSession {
 74    /// This will create a new AskPassSession.
 75    /// You must retain this session until the master process exits.
 76    #[must_use]
 77    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 78        use net::async_net::UnixListener;
 79        use util::fs::make_file_executable;
 80
 81        #[cfg(target_os = "windows")]
 82        let secret = std::sync::Arc::new(OnceLock::new());
 83        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 84        let askpass_socket = temp_dir.path().join("askpass.sock");
 85        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 86        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 87        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 88        #[cfg(not(target_os = "windows"))]
 89        let zed_path = util::get_shell_safe_zed_path()?;
 90        #[cfg(target_os = "windows")]
 91        let zed_path = std::env::current_exe()
 92            .context("finding current executable path for use in askpass")?;
 93
 94        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 95        let mut kill_tx = Some(askpass_kill_master_tx);
 96
 97        #[cfg(target_os = "windows")]
 98        let askpass_secret = secret.clone();
 99        let askpass_task = executor.spawn(async move {
100            let mut askpass_opened_tx = Some(askpass_opened_tx);
101
102            while let Ok((mut stream, _)) = listener.accept().await {
103                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
104                    askpass_opened_tx.send(()).ok();
105                }
106                let mut buffer = Vec::new();
107                let mut reader = BufReader::new(&mut stream);
108                if reader.read_until(b'\0', &mut buffer).await.is_err() {
109                    buffer.clear();
110                }
111                let prompt = String::from_utf8_lossy(&buffer);
112                if let Some(password) = delegate
113                    .ask_password(prompt.to_string())
114                    .await
115                    .context("getting askpass password")
116                    .log_err()
117                {
118                    #[cfg(target_os = "windows")]
119                    {
120                        askpass_secret.get_or_init(|| password.clone());
121                    }
122                    if let Ok(decrypted) = decrypt(password) {
123                        stream.write_all(decrypted.as_bytes()).await.log_err();
124                    }
125                } else {
126                    if let Some(kill_tx) = kill_tx.take() {
127                        kill_tx.send(()).log_err();
128                    }
129                    // note: we expect the caller to drop this task when it's done.
130                    // We need to keep the stream open until the caller is done to avoid
131                    // spurious errors from ssh.
132                    std::future::pending::<()>().await;
133                    drop(stream);
134                }
135            }
136            drop(temp_dir)
137        });
138
139        // Create an askpass script that communicates back to this process.
140        let askpass_script = generate_askpass_script(&zed_path, &askpass_socket);
141        fs::write(&askpass_script_path, askpass_script)
142            .await
143            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
144        make_file_executable(&askpass_script_path).await?;
145        #[cfg(target_os = "windows")]
146        let askpass_helper = format!(
147            "powershell.exe -ExecutionPolicy Bypass -File {}",
148            askpass_script_path.display()
149        );
150
151        Ok(Self {
152            #[cfg(not(target_os = "windows"))]
153            script_path: askpass_script_path,
154
155            #[cfg(target_os = "windows")]
156            secret,
157            #[cfg(target_os = "windows")]
158            askpass_helper,
159
160            _askpass_task: askpass_task,
161            askpass_kill_master_rx: Some(askpass_kill_master_rx),
162            askpass_opened_rx: Some(askpass_opened_rx),
163        })
164    }
165
    /// Returns the path of the generated askpass script.
    #[cfg(not(target_os = "windows"))]
    pub fn script_path(&self) -> impl AsRef<OsStr> {
        &self.script_path
    }
170
    /// Returns the askpass helper command.
    ///
    /// On Windows this is the full `powershell.exe ... -File <script>` command
    /// line built in `new`, not a bare path to the script.
    #[cfg(target_os = "windows")]
    pub fn script_path(&self) -> impl AsRef<OsStr> {
        &self.askpass_helper
    }
175
176    // This will run the askpass task forever, resolving as many authentication requests as needed.
177    // The caller is responsible for examining the result of their own commands and cancelling this
178    // future when this is no longer needed. Note that this can only be called once, but due to the
179    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
180    pub async fn run(&mut self) -> AskPassResult {
181        // This is the default timeout setting used by VSCode.
182        let connection_timeout = Duration::from_secs(17);
183        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
184        let askpass_kill_master_rx = self
185            .askpass_kill_master_rx
186            .take()
187            .expect("Only call run once");
188
189        select_biased! {
190            _ = askpass_opened_rx.fuse() => {
191                // Note: this await can only resolve after we are dropped.
192                askpass_kill_master_rx.await.ok();
193                AskPassResult::CancelledByUser
194            }
195
196            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
197                AskPassResult::Timedout
198            }
199        }
200    }
201
    /// Returns the first password the delegate supplied to the askpass task, if any.
    ///
    /// The value is stored with `OnceLock::get_or_init`, so passwords from
    /// later prompts do not replace it.
    #[cfg(target_os = "windows")]
    pub fn get_password(&self) -> Option<EncryptedPassword> {
        self.secret.get().cloned()
    }
207}
208
209/// The main function for when Zed is running in netcat mode for use in askpass.
210/// Called from both the remote server binary and the zed binary in their respective main functions.
211pub fn main(socket: &str) {
212    use net::UnixStream;
213    use std::io::{self, Read, Write};
214    use std::process::exit;
215
216    let mut stream = match UnixStream::connect(socket) {
217        Ok(stream) => stream,
218        Err(err) => {
219            eprintln!("Error connecting to socket {}: {}", socket, err);
220            exit(1);
221        }
222    };
223
224    let mut buffer = Vec::new();
225    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
226        eprintln!("Error reading from stdin: {}", err);
227        exit(1);
228    }
229
230    #[cfg(target_os = "windows")]
231    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
232        buffer.pop();
233    }
234    if buffer.last() != Some(&b'\0') {
235        buffer.push(b'\0');
236    }
237
238    if let Err(err) = stream.write_all(&buffer) {
239        eprintln!("Error writing to socket: {}", err);
240        exit(1);
241    }
242
243    let mut response = Vec::new();
244    if let Err(err) = stream.read_to_end(&mut response) {
245        eprintln!("Error reading from socket: {}", err);
246        exit(1);
247    }
248
249    if let Err(err) = io::stdout().write_all(&response) {
250        eprintln!("Error writing to stdout: {}", err);
251        exit(1);
252    }
253}
254
/// Builds the `/bin/sh` askpass script.
///
/// The script prints each of its arguments followed by a NUL byte and pipes
/// that into `zed_path --askpass=<socket>`, discarding stderr.
///
/// `zed_path` is inserted verbatim and must already be safe to embed in a
/// shell command. `askpass_socket` is wrapped in single quotes (with embedded
/// `'` escaped as `'\''`) so paths containing spaces or shell metacharacters
/// reach the executable as a single, unmodified argument.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_path: &str, askpass_socket: &std::path::Path) -> String {
    // Inside single quotes the shell treats everything literally; a quote
    // itself is written by closing, emitting `\'`, and reopening.
    let quoted_socket = askpass_socket.display().to_string().replace('\'', r"'\''");
    format!(
        "{shebang}\n{print_args} | {zed_exe} --askpass='{askpass_socket}' 2> /dev/null \n",
        zed_exe = zed_path,
        askpass_socket = quoted_socket,
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
266
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result
/// into `zed_path --askpass=<socket>`, discarding the error stream.
///
/// Both paths are embedded as PowerShell single-quoted (literal) strings, with
/// embedded `'` doubled, so spaces, `$` and backticks in either path are
/// passed through unchanged instead of splitting the argument or being
/// expanded.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_path: &std::path::Path, askpass_socket: &std::path::Path) -> String {
    // In a single-quoted PowerShell string the only character needing an
    // escape is the quote itself, which is written as `''`.
    let escape = |path: &std::path::Path| path.display().to_string().replace('\'', "''");
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & '{zed_exe}' '--askpass={askpass_socket}' 2> $null
        "#,
        zed_exe = escape(zed_path),
        askpass_socket = escape(askpass_socket),
    )
}