askpass.rs

  1use std::{ffi::OsStr, time::Duration};
  2
  3use anyhow::{Context as _, Result};
  4use futures::channel::{mpsc, oneshot};
  5use futures::{
  6    AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
  7    select_biased,
  8};
  9use gpui::{AsyncApp, BackgroundExecutor, Task};
 10use smol::fs;
 11use util::ResultExt as _;
 12
 13#[derive(PartialEq, Eq)]
 14pub enum AskPassResult {
 15    CancelledByUser,
 16    Timedout,
 17}
 18
 19pub struct AskPassDelegate {
 20    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
 21    _task: Task<()>,
 22}
 23
 24impl AskPassDelegate {
 25    pub fn new(
 26        cx: &mut AsyncApp,
 27        password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
 28    ) -> Self {
 29        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
 30        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 31            while let Some((prompt, channel)) = rx.next().await {
 32                password_prompt(prompt, channel, cx);
 33            }
 34        });
 35        Self { tx, _task: task }
 36    }
 37
 38    pub async fn ask_password(&mut self, prompt: String) -> Result<String> {
 39        let (tx, rx) = oneshot::channel();
 40        self.tx.send((prompt, tx)).await?;
 41        Ok(rx.await?)
 42    }
 43
 44    pub fn new_always_failing() -> Self {
 45        let (tx, _rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
 46        Self {
 47            tx,
 48            _task: Task::ready(()),
 49        }
 50    }
 51}
 52
 53pub struct AskPassSession {
 54    #[cfg(not(target_os = "windows"))]
 55    script_path: std::path::PathBuf,
 56    #[cfg(not(target_os = "windows"))]
 57    gpg_script_path: std::path::PathBuf,
 58    #[cfg(target_os = "windows")]
 59    askpass_helper: String,
 60    #[cfg(target_os = "windows")]
 61    secret: std::sync::Arc<parking_lot::Mutex<String>>,
 62    _askpass_task: Task<()>,
 63    askpass_opened_rx: Option<oneshot::Receiver<()>>,
 64    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 65}
 66
 67#[cfg(not(target_os = "windows"))]
 68const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
 69#[cfg(target_os = "windows")]
 70const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
 71
 72#[cfg(not(target_os = "windows"))]
 73const GPG_SCRIPT_NAME: &str = "gpg.sh";
 74
 75impl AskPassSession {
 76    /// This will create a new AskPassSession.
 77    /// You must retain this session until the master process exits.
 78    #[must_use]
 79    pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
 80        use net::async_net::UnixListener;
 81        use util::fs::make_file_executable;
 82
 83        #[cfg(target_os = "windows")]
 84        let secret = std::sync::Arc::new(parking_lot::Mutex::new(String::new()));
 85        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 86        let askpass_socket = temp_dir.path().join("askpass.sock");
 87        let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
 88        #[cfg(not(target_os = "windows"))]
 89        let gpg_script_path = temp_dir.path().join(GPG_SCRIPT_NAME);
 90        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 91        let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
 92        #[cfg(not(target_os = "windows"))]
 93        let zed_path = util::get_shell_safe_zed_path()?;
 94        #[cfg(target_os = "windows")]
 95        let zed_path = std::env::current_exe()
 96            .context("finding current executable path for use in askpass")?;
 97
 98        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 99        let mut kill_tx = Some(askpass_kill_master_tx);
100
101        #[cfg(target_os = "windows")]
102        let askpass_secret = secret.clone();
103        let askpass_task = executor.spawn(async move {
104            let mut askpass_opened_tx = Some(askpass_opened_tx);
105
106            while let Ok((mut stream, _)) = listener.accept().await {
107                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
108                    askpass_opened_tx.send(()).ok();
109                }
110                let mut buffer = Vec::new();
111                let mut reader = BufReader::new(&mut stream);
112                if reader.read_until(b'\0', &mut buffer).await.is_err() {
113                    buffer.clear();
114                }
115                let prompt = String::from_utf8_lossy(&buffer);
116                if let Some(password) = delegate
117                    .ask_password(prompt.to_string())
118                    .await
119                    .context("getting askpass password")
120                    .log_err()
121                {
122                    stream.write_all(password.as_bytes()).await.log_err();
123                    #[cfg(target_os = "windows")]
124                    {
125                        *askpass_secret.lock() = password;
126                    }
127                } else {
128                    if let Some(kill_tx) = kill_tx.take() {
129                        kill_tx.send(()).log_err();
130                    }
131                    // note: we expect the caller to drop this task when it's done.
132                    // We need to keep the stream open until the caller is done to avoid
133                    // spurious errors from ssh.
134                    std::future::pending::<()>().await;
135                    drop(stream);
136                }
137            }
138            drop(temp_dir)
139        });
140
141        // Create an askpass script that communicates back to this process.
142        let askpass_script = generate_askpass_script(&zed_path, &askpass_socket);
143        fs::write(&askpass_script_path, askpass_script)
144            .await
145            .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
146        make_file_executable(&askpass_script_path).await?;
147        #[cfg(target_os = "windows")]
148        let askpass_helper = format!(
149            "powershell.exe -ExecutionPolicy Bypass -File {}",
150            askpass_script_path.display()
151        );
152
153        #[cfg(not(target_os = "windows"))]
154        {
155            let gpg_script = generate_gpg_script();
156            fs::write(&gpg_script_path, gpg_script)
157                .await
158                .with_context(|| format!("creating gpg wrapper script at {gpg_script_path:?}"))?;
159            make_file_executable(&gpg_script_path).await?;
160        }
161
162        Ok(Self {
163            #[cfg(not(target_os = "windows"))]
164            script_path: askpass_script_path,
165            #[cfg(not(target_os = "windows"))]
166            gpg_script_path,
167
168            #[cfg(target_os = "windows")]
169            secret,
170            #[cfg(target_os = "windows")]
171            askpass_helper,
172
173            _askpass_task: askpass_task,
174            askpass_kill_master_rx: Some(askpass_kill_master_rx),
175            askpass_opened_rx: Some(askpass_opened_rx),
176        })
177    }
178
179    #[cfg(not(target_os = "windows"))]
180    pub fn script_path(&self) -> impl AsRef<OsStr> {
181        &self.script_path
182    }
183
184    #[cfg(target_os = "windows")]
185    pub fn script_path(&self) -> impl AsRef<OsStr> {
186        &self.askpass_helper
187    }
188
189    #[cfg(not(target_os = "windows"))]
190    pub fn gpg_script_path(&self) -> Option<impl AsRef<OsStr>> {
191        Some(&self.gpg_script_path)
192    }
193
194    #[cfg(target_os = "windows")]
195    pub fn gpg_script_path(&self) -> Option<impl AsRef<OsStr>> {
196        // TODO implement wrapping GPG on Windows. This is more difficult than on Unix
197        // because we can't use --passphrase-fd with a nonstandard FD, and both --passphrase
198        // and --passphrase-file are insecure.
199        None::<std::path::PathBuf>
200    }
201
202    // This will run the askpass task forever, resolving as many authentication requests as needed.
203    // The caller is responsible for examining the result of their own commands and cancelling this
204    // future when this is no longer needed. Note that this can only be called once, but due to the
205    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
206    pub async fn run(&mut self) -> AskPassResult {
207        // This is the default timeout setting used by VSCode.
208        let connection_timeout = Duration::from_secs(17);
209        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
210        let askpass_kill_master_rx = self
211            .askpass_kill_master_rx
212            .take()
213            .expect("Only call run once");
214
215        select_biased! {
216            _ = askpass_opened_rx.fuse() => {
217                // Note: this await can only resolve after we are dropped.
218                askpass_kill_master_rx.await.ok();
219                return AskPassResult::CancelledByUser
220            }
221
222            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
223                return AskPassResult::Timedout
224            }
225        }
226    }
227
228    /// This will return the password that was last set by the askpass script.
229    #[cfg(target_os = "windows")]
230    pub fn get_password(&self) -> String {
231        self.secret.lock().clone()
232    }
233}
234
235/// The main function for when Zed is running in netcat mode for use in askpass.
236/// Called from both the remote server binary and the zed binary in their respective main functions.
237pub fn main(socket: &str) {
238    use net::UnixStream;
239    use std::io::{self, Read, Write};
240    use std::process::exit;
241
242    let mut stream = match UnixStream::connect(socket) {
243        Ok(stream) => stream,
244        Err(err) => {
245            eprintln!("Error connecting to socket {}: {}", socket, err);
246            exit(1);
247        }
248    };
249
250    let mut buffer = Vec::new();
251    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
252        eprintln!("Error reading from stdin: {}", err);
253        exit(1);
254    }
255
256    #[cfg(target_os = "windows")]
257    while buffer.last().map_or(false, |&b| b == b'\n' || b == b'\r') {
258        buffer.pop();
259    }
260    if buffer.last() != Some(&b'\0') {
261        buffer.push(b'\0');
262    }
263
264    if let Err(err) = stream.write_all(&buffer) {
265        eprintln!("Error writing to socket: {}", err);
266        exit(1);
267    }
268
269    let mut response = Vec::new();
270    if let Err(err) = stream.read_to_end(&mut response) {
271        eprintln!("Error reading from socket: {}", err);
272        exit(1);
273    }
274
275    if let Err(err) = io::stdout().write_all(&response) {
276        eprintln!("Error writing to stdout: {}", err);
277        exit(1);
278    }
279}
280
/// Builds the `/bin/sh` askpass script. It writes its NUL-separated arguments to
/// `zed_path --askpass=<socket>`, which relays them to the askpass socket and prints
/// the reply.
///
/// `zed_path` is interpolated as-is and must already be shell-safe. The socket path
/// lives in a temporary directory that may contain whitespace or shell metacharacters,
/// so it is single-quoted here.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_path: &str, askpass_socket: &std::path::Path) -> String {
    // POSIX single quotes suppress all expansion; an embedded `'` is written as `'\''`
    // (close the quote, emit an escaped quote, reopen the quote).
    let quoted_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', r"'\''")
    );
    format!(
        "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
        zed_exe = zed_path,
        askpass_socket = quoted_socket,
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
292
/// Builds the PowerShell askpass script. It joins its arguments with NUL and pipes them
/// into `zed_path --askpass=<socket>`, discarding that process's stderr.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_path: &std::path::Path, askpass_socket: &std::path::Path) -> String {
    let exe = zed_path.display();
    let socket = askpass_socket.display();
    // Every line, including the trailing empty one, carries eight spaces of indentation.
    let indent = "        ";
    format!(
        "\n{indent}$ErrorActionPreference = 'Stop';\n{indent}($args -join [char]0) | & \"{exe}\" --askpass={socket} 2> $null\n{indent}"
    )
}
305
306#[inline]
307#[cfg(not(target_os = "windows"))]
308fn generate_gpg_script() -> String {
309    use unindent::Unindent as _;
310
311    r#"
312        #!/bin/sh
313        set -eu
314
315        unset GIT_CONFIG_PARAMETERS
316        GPG_PROGRAM=$(git config gpg.program || echo 'gpg')
317        PROMPT="Enter passphrase to unlock GPG key:"
318        PASSPHRASE=$(${GIT_ASKPASS} "${PROMPT}")
319
320        exec "${GPG_PROGRAM}" --batch --no-tty --yes --passphrase-fd 3 --pinentry-mode loopback "$@" 3<<EOF
321        ${PASSPHRASE}
322        EOF
323    "#.unindent()
324}