askpass.rs

  1use std::{ffi::OsStr, time::Duration};
  2
  3use anyhow::{Context as _, Result};
  4use futures::channel::{mpsc, oneshot};
  5use futures::{
  6    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
  7    select_biased,
  8};
  9use gpui::{AsyncApp, BackgroundExecutor, Task};
 10use smol::fs;
 11use util::ResultExt as _;
 12
 13#[derive(PartialEq, Eq)]
 14pub enum AskPassResult {
 15    CancelledByUser,
 16    Timedout,
 17}
 18
/// Routes password prompts to a caller-supplied callback that runs on a spawned task.
pub struct AskPassDelegate {
    /// Sending half of the request queue: each entry is the prompt text plus the one-shot
    /// channel on which the password is expected to be returned.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
    /// Handle to the task that drains the request queue; stored so the task lives as long
    /// as the delegate does.
    _task: Task<()>,
}
 23
 24impl AskPassDelegate {
 25    pub fn new(
 26        cx: &mut AsyncApp,
 27        password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
 28    ) -> Self {
 29        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
 30        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 31            while let Some((prompt, channel)) = rx.next().await {
 32                password_prompt(prompt, channel, cx);
 33            }
 34        });
 35        Self { tx, _task: task }
 36    }
 37
 38    pub async fn ask_password(&mut self, prompt: String) -> Result<String> {
 39        let (tx, rx) = oneshot::channel();
 40        self.tx.send((prompt, tx)).await?;
 41        Ok(rx.await?)
 42    }
 43}
 44
/// A live askpass helper: a generated script in a temporary directory plus a Unix-socket
/// listener task that answers the script's password requests.
pub struct AskPassSession {
    /// Path of the generated askpass shell script.
    #[cfg(not(target_os = "windows"))]
    script_path: std::path::PathBuf,
    /// Command line that runs the generated PowerShell askpass script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
    /// Most recent password written back to the script; shared with the listener task.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<parking_lot::Mutex<String>>,
    /// Listener task. It owns the temporary directory that holds the script and socket.
    _askpass_task: Task<()>,
    /// Resolves once the script first connects to the socket; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Resolves when the delegate fails to supply a password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
 56
 57#[cfg(not(target_os = "windows"))]
 58const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 59#[cfg(target_os = "windows")]
 60const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 61
 62impl AskPassSession {
 63    /// This will create a new AskPassSession.
 64    /// You must retain this session until the master process exits.
 65    #[must_use]
 66    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 67        use net::async_net::UnixListener;
 68        use util::fs::make_file_executable;
 69
 70        #[cfg(target_os = "windows")]
 71        let secret = std::sync::Arc::new(parking_lot::Mutex::new(String::new()));
 72        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 73        let askpass_socket = temp_dir.path().join("askpass.sock");
 74        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 75        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 76        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 77        #[cfg(not(target_os = "windows"))]
 78        let zed_path = util::get_shell_safe_zed_path()?;
 79        #[cfg(target_os = "windows")]
 80        let zed_path = std::env::current_exe()
 81            .context("finding current executable path for use in askpass")?;
 82
 83        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 84        let mut kill_tx = Some(askpass_kill_master_tx);
 85
 86        #[cfg(target_os = "windows")]
 87        let askpass_secret = secret.clone();
 88        let askpass_task = executor.spawn(async move {
 89            let mut askpass_opened_tx = Some(askpass_opened_tx);
 90
 91            while let Ok((mut stream, _)) = listener.accept().await {
 92                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
 93                    askpass_opened_tx.send(()).ok();
 94                }
 95                let mut buffer = Vec::new();
 96                let mut reader = BufReader::new(&mut stream);
 97                if reader.read_until(b'\0', &mut buffer).await.is_err() {
 98                    buffer.clear();
 99                }
100                let prompt = String::from_utf8_lossy(&buffer);
101                if let Some(password) = delegate
102                    .ask_password(prompt.to_string())
103                    .await
104                    .context("getting askpass password")
105                    .log_err()
106                {
107                    stream.write_all(password.as_bytes()).await.log_err();
108                    #[cfg(target_os = "windows")]
109                    {
110                        *askpass_secret.lock() = password;
111                    }
112                } else {
113                    if let Some(kill_tx) = kill_tx.take() {
114                        kill_tx.send(()).log_err();
115                    }
116                    // note: we expect the caller to drop this task when it's done.
117                    // We need to keep the stream open until the caller is done to avoid
118                    // spurious errors from ssh.
119                    std::future::pending::<()>().await;
120                    drop(stream);
121                }
122            }
123            drop(temp_dir)
124        });
125
126        // Create an askpass script that communicates back to this process.
127        let askpass_script = generate_askpass_script(&zed_path, &askpass_socket);
128        fs::write(&askpass_script_path, askpass_script)
129            .await
130            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
131        make_file_executable(&askpass_script_path).await?;
132        #[cfg(target_os = "windows")]
133        let askpass_helper = format!(
134            "powershell.exe -ExecutionPolicy Bypass -File {}",
135            askpass_script_path.display()
136        );
137
138        Ok(Self {
139            #[cfg(not(target_os = "windows"))]
140            script_path: askpass_script_path,
141
142            #[cfg(target_os = "windows")]
143            secret,
144            #[cfg(target_os = "windows")]
145            askpass_helper,
146
147            _askpass_task: askpass_task,
148            askpass_kill_master_rx: Some(askpass_kill_master_rx),
149            askpass_opened_rx: Some(askpass_opened_rx),
150        })
151    }
152
153    #[cfg(not(target_os = "windows"))]
154    pub fn script_path(&self) -> impl AsRef<OsStr> {
155        &self.script_path
156    }
157
158    #[cfg(target_os = "windows")]
159    pub fn script_path(&self) -> impl AsRef<OsStr> {
160        &self.askpass_helper
161    }
162
163    // This will run the askpass task forever, resolving as many authentication requests as needed.
164    // The caller is responsible for examining the result of their own commands and cancelling this
165    // future when this is no longer needed. Note that this can only be called once, but due to the
166    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
167    pub async fn run(&mut self) -> AskPassResult {
168        // This is the default timeout setting used by VSCode.
169        let connection_timeout = Duration::from_secs(17);
170        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
171        let askpass_kill_master_rx = self
172            .askpass_kill_master_rx
173            .take()
174            .expect("Only call run once");
175
176        select_biased! {
177            _ = askpass_opened_rx.fuse() => {
178                // Note: this await can only resolve after we are dropped.
179                askpass_kill_master_rx.await.ok();
180                AskPassResult::CancelledByUser
181            }
182
183            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
184                AskPassResult::Timedout
185            }
186        }
187    }
188
189    /// This will return the password that was last set by the askpass script.
190    #[cfg(target_os = "windows")]
191    pub fn get_password(&self) -> String {
192        self.secret.lock().clone()
193    }
194}
195
196/// The main function for when Zed is running in netcat mode for use in askpass.
197/// Called from both the remote server binary and the zed binary in their respective main functions.
198pub fn main(socket: &str) {
199    use net::UnixStream;
200    use std::io::{self, Read, Write};
201    use std::process::exit;
202
203    let mut stream = match UnixStream::connect(socket) {
204        Ok(stream) => stream,
205        Err(err) => {
206            eprintln!("Error connecting to socket {}: {}", socket, err);
207            exit(1);
208        }
209    };
210
211    let mut buffer = Vec::new();
212    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
213        eprintln!("Error reading from stdin: {}", err);
214        exit(1);
215    }
216
217    #[cfg(target_os = "windows")]
218    while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
219        buffer.pop();
220    }
221    if buffer.last() != Some(&b'\0') {
222        buffer.push(b'\0');
223    }
224
225    if let Err(err) = stream.write_all(&buffer) {
226        eprintln!("Error writing to socket: {}", err);
227        exit(1);
228    }
229
230    let mut response = Vec::new();
231    if let Err(err) = stream.read_to_end(&mut response) {
232        eprintln!("Error reading from socket: {}", err);
233        exit(1);
234    }
235
236    if let Err(err) = io::stdout().write_all(&response) {
237        eprintln!("Error writing to stdout: {}", err);
238        exit(1);
239    }
240}
241
/// Builds the POSIX `sh` askpass script.
///
/// The script writes each of its arguments followed by a NUL byte and pipes that into
/// `zed_path --askpass=<socket>`, discarding the child's stderr.
///
/// `zed_path` is inserted verbatim, so it must already be shell-safe. The socket path is
/// single-quoted so that a temporary directory containing spaces or other shell
/// metacharacters cannot split or alter the command.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_path: &str, askpass_socket: &std::path::Path) -> String {
    /// Quotes `value` as a single POSIX shell word.
    fn shell_single_quote(value: &str) -> String {
        // Inside single quotes nothing is special except `'` itself. Write it by closing the
        // quote, emitting an escaped quote, and reopening: `'\''`.
        format!("'{}'", value.replace('\'', r"'\''"))
    }

    format!(
        "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
        zed_exe = zed_path,
        askpass_socket = shell_single_quote(&askpass_socket.display().to_string()),
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
253
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result into
/// `zed_path --askpass=<socket>`, discarding the child's stderr.
///
/// Both the executable and the whole `--askpass=` argument are double-quoted. This keeps
/// paths containing spaces (for example a temporary directory under a user profile named
/// like `Jane Doe`) as a single argument.
///
/// NOTE(review): PowerShell still expands `$` and backticks inside double quotes, so paths
/// containing those characters are not handled.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_path: &std::path::Path, askpass_socket: &std::path::Path) -> String {
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{zed_exe}" "--askpass={askpass_socket}" 2> $null
        "#,
        zed_exe = zed_path.display(),
        askpass_socket = askpass_socket.display(),
    )
}