askpass.rs

  1use std::path::{Path, PathBuf};
  2use std::time::Duration;
  3
  4#[cfg(unix)]
  5use anyhow::Context as _;
  6use futures::channel::{mpsc, oneshot};
  7#[cfg(unix)]
  8use futures::{AsyncBufReadExt as _, io::BufReader};
  9#[cfg(unix)]
 10use futures::{AsyncWriteExt as _, FutureExt as _, select_biased};
 11use futures::{SinkExt, StreamExt};
 12use gpui::{AsyncApp, BackgroundExecutor, Task};
 13#[cfg(unix)]
 14use smol::fs;
 15#[cfg(unix)]
 16use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
 17#[cfg(unix)]
 18use util::ResultExt as _;
 19
 20#[derive(PartialEq, Eq)]
 21pub enum AskPassResult {
 22    CancelledByUser,
 23    Timedout,
 24}
 25
/// Forwards password prompts to a caller-supplied callback.
///
/// Each request sent through `tx` is picked up by the task spawned in
/// [`AskPassDelegate::new`], which invokes the callback with the prompt text and
/// the one-shot sender on which the answer is expected.
pub struct AskPassDelegate {
    // Queue of `(prompt text, reply channel)` pairs consumed by the spawned task.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
    // Handle of the task that drains `tx`; stored so the delegate owns it.
    // NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    _task: Task<()>,
}
 30
 31impl AskPassDelegate {
 32    pub fn new(
 33        cx: &mut AsyncApp,
 34        password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
 35    ) -> Self {
 36        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
 37        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 38            while let Some((prompt, channel)) = rx.next().await {
 39                password_prompt(prompt, channel, cx);
 40            }
 41        });
 42        Self { tx, _task: task }
 43    }
 44
 45    pub async fn ask_password(&mut self, prompt: String) -> anyhow::Result<String> {
 46        let (tx, rx) = oneshot::channel();
 47        self.tx.send((prompt, tx)).await?;
 48        Ok(rx.await?)
 49    }
 50}
 51
#[cfg(unix)]
/// Unix askpass session: a generated helper script plus the background task that
/// serves the script's socket connections.
pub struct AskPassSession {
    // Path of the generated `askpass.sh` script (see `script_path`).
    script_path: PathBuf,
    // Background task that accepts socket connections and answers prompts via the
    // delegate. It also owns the temporary directory holding the socket and script.
    _askpass_task: Task<()>,
    // Signalled when the first socket connection is accepted; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    // Signalled when the delegate failed to produce a password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
 59
 60#[cfg(unix)]
 61impl AskPassSession {
 62    /// This will create a new AskPassSession.
 63    /// You must retain this session until the master process exits.
 64    #[must_use]
 65    pub async fn new(
 66        executor: &BackgroundExecutor,
 67        mut delegate: AskPassDelegate,
 68    ) -> anyhow::Result<Self> {
 69        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 70        let askpass_socket = temp_dir.path().join("askpass.sock");
 71        let askpass_script_path = temp_dir.path().join("askpass.sh");
 72        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 73        let listener =
 74            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 75        let zed_path = get_shell_safe_zed_path()?;
 76
 77        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 78        let mut kill_tx = Some(askpass_kill_master_tx);
 79
 80        let askpass_task = executor.spawn(async move {
 81            let mut askpass_opened_tx = Some(askpass_opened_tx);
 82
 83            while let Ok((mut stream, _)) = listener.accept().await {
 84                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
 85                    askpass_opened_tx.send(()).ok();
 86                }
 87                let mut buffer = Vec::new();
 88                let mut reader = BufReader::new(&mut stream);
 89                if reader.read_until(b'\0', &mut buffer).await.is_err() {
 90                    buffer.clear();
 91                }
 92                let prompt = String::from_utf8_lossy(&buffer);
 93                if let Some(password) = delegate
 94                    .ask_password(prompt.to_string())
 95                    .await
 96                    .context("failed to get askpass password")
 97                    .log_err()
 98                {
 99                    stream.write_all(password.as_bytes()).await.log_err();
100                } else {
101                    if let Some(kill_tx) = kill_tx.take() {
102                        kill_tx.send(()).log_err();
103                    }
104                    // note: we expect the caller to drop this task when it's done.
105                    // We need to keep the stream open until the caller is done to avoid
106                    // spurious errors from ssh.
107                    std::future::pending::<()>().await;
108                    drop(stream);
109                }
110            }
111            drop(temp_dir)
112        });
113
114        // Create an askpass script that communicates back to this process.
115        let askpass_script = format!(
116            "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
117            zed_exe = zed_path,
118            askpass_socket = askpass_socket.display(),
119            print_args = "printf '%s\\0' \"$@\"",
120            shebang = "#!/bin/sh",
121        );
122        fs::write(&askpass_script_path, askpass_script).await?;
123        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
124
125        Ok(Self {
126            script_path: askpass_script_path,
127            _askpass_task: askpass_task,
128            askpass_kill_master_rx: Some(askpass_kill_master_rx),
129            askpass_opened_rx: Some(askpass_opened_rx),
130        })
131    }
132
133    pub fn script_path(&self) -> &Path {
134        &self.script_path
135    }
136
137    // This will run the askpass task forever, resolving as many authentication requests as needed.
138    // The caller is responsible for examining the result of their own commands and cancelling this
139    // future when this is no longer needed. Note that this can only be called once, but due to the
140    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
141    pub async fn run(&mut self) -> AskPassResult {
142        let connection_timeout = Duration::from_secs(10);
143        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
144        let askpass_kill_master_rx = self
145            .askpass_kill_master_rx
146            .take()
147            .expect("Only call run once");
148
149        select_biased! {
150            _ = askpass_opened_rx.fuse() => {
151                // Note: this await can only resolve after we are dropped.
152                askpass_kill_master_rx.await.ok();
153                return AskPassResult::CancelledByUser
154            }
155
156            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
157                return AskPassResult::Timedout
158            }
159        }
160    }
161}
162
163#[cfg(unix)]
164fn get_shell_safe_zed_path() -> anyhow::Result<String> {
165    let zed_path = std::env::current_exe()
166        .context("Failed to determine current executable path for use in askpass")?
167        .to_string_lossy()
168        // see https://github.com/rust-lang/rust/issues/69343
169        .trim_end_matches(" (deleted)")
170        .to_string();
171
172    // NOTE: this was previously enabled, however, it caused errors when it shouldn't have
173    //       (see https://github.com/zed-industries/zed/issues/29819)
174    //       The zed path failing to execute within the askpass script results in very vague ssh
175    //       authentication failed errors, so this was done to try and surface a better error
176    //
177    // use std::os::unix::fs::MetadataExt;
178    // let metadata = std::fs::metadata(&zed_path)
179    //     .context("Failed to check metadata of Zed executable path for use in askpass")?;
180    // let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
181    // anyhow::ensure!(
182    //     is_executable,
183    //     "Failed to verify Zed executable path for use in askpass"
184    // );
185
186    // As of writing, this can only be fail if the path contains a null byte, which shouldn't be possible
187    // but shlex has annotated the error as #[non_exhaustive] so we can't make it a compile error if other
188    // errors are introduced in the future :(
189    let zed_path_escaped = shlex::try_quote(&zed_path)
190        .context("Failed to shell-escape Zed executable path for use in askpass")?;
191
192    return Ok(zed_path_escaped.to_string());
193}
194
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
///
/// Connects to the Unix socket at `socket`, forwards everything read from stdin
/// (adding a terminating NUL byte if missing), then copies the socket's reply to
/// stdout. Any I/O failure is reported on stderr and the process exits with
/// status 1.
#[cfg(unix)]
pub fn main(socket: &str) {
    use std::io::{self, Read, Write};
    use std::os::unix::net::UnixStream;
    use std::process::exit;

    let mut stream = match UnixStream::connect(socket) {
        Ok(stream) => stream,
        Err(err) => {
            eprintln!("Error connecting to socket {}: {}", socket, err);
            exit(1);
        }
    };

    let mut buffer = Vec::new();
    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
        eprintln!("Error reading from stdin: {}", err);
        exit(1);
    }

    // The listening side reads up to the first NUL byte, so make sure one is sent.
    if buffer.last() != Some(&b'\0') {
        buffer.push(b'\0');
    }

    if let Err(err) = stream.write_all(&buffer) {
        eprintln!("Error writing to socket: {}", err);
        exit(1);
    }

    // The reply is complete once the other side closes the connection.
    let mut response = Vec::new();
    if let Err(err) = stream.read_to_end(&mut response) {
        eprintln!("Error reading from socket: {}", err);
        exit(1);
    }

    // The reply carries no trailing newline, so it may sit in stdout's buffer.
    // Flush explicitly so a failed write is reported here instead of being
    // silently dropped when the process shuts down.
    let mut stdout = io::stdout().lock();
    if let Err(err) = stdout
        .write_all(&response)
        .and_then(|()| stdout.flush())
    {
        eprintln!("Error writing to stdout: {}", err);
        exit(1);
    }
}
#[cfg(not(unix))]
/// Non-Unix counterpart of the askpass netcat entry point: does nothing and
/// ignores `_socket`.
pub fn main(_socket: &str) {}
239
#[cfg(not(unix))]
/// Non-Unix placeholder for the askpass session; it holds only a script path.
pub struct AskPassSession {
    // Path returned by `script_path`; `new` initialises it to an empty path.
    path: PathBuf,
}
244
#[cfg(not(unix))]
impl AskPassSession {
    /// Creates a placeholder session; both arguments are ignored and the script
    /// path is left empty.
    pub async fn new(_: &BackgroundExecutor, _: AskPassDelegate) -> anyhow::Result<Self> {
        let path = PathBuf::new();
        Ok(Self { path })
    }

    /// Path of the askpass script (empty for sessions created by `new`).
    pub fn script_path(&self) -> &Path {
        self.path.as_path()
    }

    /// Waits 20 seconds and then reports a timeout.
    pub async fn run(&mut self) -> AskPassResult {
        let timeout = Duration::from_secs(20);
        let timer = futures::FutureExt::fuse(smol::Timer::after(timeout));
        timer.await;
        AskPassResult::Timedout
    }
}