askpass.rs

  1use std::path::{Path, PathBuf};
  2use std::time::Duration;
  3
  4#[cfg(unix)]
  5use anyhow::Context as _;
  6use futures::channel::{mpsc, oneshot};
  7#[cfg(unix)]
  8use futures::{AsyncBufReadExt as _, io::BufReader};
  9#[cfg(unix)]
 10use futures::{AsyncWriteExt as _, FutureExt as _, select_biased};
 11use futures::{SinkExt, StreamExt};
 12use gpui::{AsyncApp, BackgroundExecutor, Task};
 13#[cfg(unix)]
 14use smol::fs;
 15#[cfg(unix)]
 16use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
 17#[cfg(unix)]
 18use util::ResultExt as _;
 19#[cfg(unix)]
 20use util::get_shell_safe_zed_path;
 21
 22#[derive(PartialEq, Eq)]
 23pub enum AskPassResult {
 24    CancelledByUser,
 25    Timedout,
 26}
 27
 28pub struct AskPassDelegate {
 29    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
 30    _task: Task<()>,
 31}
 32
 33impl AskPassDelegate {
 34    pub fn new(
 35        cx: &mut AsyncApp,
 36        password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
 37    ) -> Self {
 38        let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
 39        let task = cx.spawn(async move |cx: &mut AsyncApp| {
 40            while let Some((prompt, channel)) = rx.next().await {
 41                password_prompt(prompt, channel, cx);
 42            }
 43        });
 44        Self { tx, _task: task }
 45    }
 46
 47    pub async fn ask_password(&mut self, prompt: String) -> anyhow::Result<String> {
 48        let (tx, rx) = oneshot::channel();
 49        self.tx.send((prompt, tx)).await?;
 50        Ok(rx.await?)
 51    }
 52}
 53
 54#[cfg(unix)]
 55pub struct AskPassSession {
 56    script_path: PathBuf,
 57    _askpass_task: Task<()>,
 58    askpass_opened_rx: Option<oneshot::Receiver<()>>,
 59    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
 60}
 61
 62#[cfg(unix)]
 63impl AskPassSession {
 64    /// This will create a new AskPassSession.
 65    /// You must retain this session until the master process exits.
 66    #[must_use]
 67    pub async fn new(
 68        executor: &BackgroundExecutor,
 69        mut delegate: AskPassDelegate,
 70    ) -> anyhow::Result<Self> {
 71        let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
 72        let askpass_socket = temp_dir.path().join("askpass.sock");
 73        let askpass_script_path = temp_dir.path().join("askpass.sh");
 74        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
 75        let listener =
 76            UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
 77        let zed_path = get_shell_safe_zed_path()?;
 78
 79        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
 80        let mut kill_tx = Some(askpass_kill_master_tx);
 81
 82        let askpass_task = executor.spawn(async move {
 83            let mut askpass_opened_tx = Some(askpass_opened_tx);
 84
 85            while let Ok((mut stream, _)) = listener.accept().await {
 86                if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
 87                    askpass_opened_tx.send(()).ok();
 88                }
 89                let mut buffer = Vec::new();
 90                let mut reader = BufReader::new(&mut stream);
 91                if reader.read_until(b'\0', &mut buffer).await.is_err() {
 92                    buffer.clear();
 93                }
 94                let prompt = String::from_utf8_lossy(&buffer);
 95                if let Some(password) = delegate
 96                    .ask_password(prompt.to_string())
 97                    .await
 98                    .context("failed to get askpass password")
 99                    .log_err()
100                {
101                    stream.write_all(password.as_bytes()).await.log_err();
102                } else {
103                    if let Some(kill_tx) = kill_tx.take() {
104                        kill_tx.send(()).log_err();
105                    }
106                    // note: we expect the caller to drop this task when it's done.
107                    // We need to keep the stream open until the caller is done to avoid
108                    // spurious errors from ssh.
109                    std::future::pending::<()>().await;
110                    drop(stream);
111                }
112            }
113            drop(temp_dir)
114        });
115
116        // Create an askpass script that communicates back to this process.
117        let askpass_script = format!(
118            "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
119            zed_exe = zed_path,
120            askpass_socket = askpass_socket.display(),
121            print_args = "printf '%s\\0' \"$@\"",
122            shebang = "#!/bin/sh",
123        );
124        fs::write(&askpass_script_path, askpass_script).await?;
125        fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
126
127        Ok(Self {
128            script_path: askpass_script_path,
129            _askpass_task: askpass_task,
130            askpass_kill_master_rx: Some(askpass_kill_master_rx),
131            askpass_opened_rx: Some(askpass_opened_rx),
132        })
133    }
134
135    pub fn script_path(&self) -> &Path {
136        &self.script_path
137    }
138
139    // This will run the askpass task forever, resolving as many authentication requests as needed.
140    // The caller is responsible for examining the result of their own commands and cancelling this
141    // future when this is no longer needed. Note that this can only be called once, but due to the
142    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
143    pub async fn run(&mut self) -> AskPassResult {
144        let connection_timeout = Duration::from_secs(10);
145        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
146        let askpass_kill_master_rx = self
147            .askpass_kill_master_rx
148            .take()
149            .expect("Only call run once");
150
151        select_biased! {
152            _ = askpass_opened_rx.fuse() => {
153                // Note: this await can only resolve after we are dropped.
154                askpass_kill_master_rx.await.ok();
155                return AskPassResult::CancelledByUser
156            }
157
158            _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
159                return AskPassResult::Timedout
160            }
161        }
162    }
163}
164
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
///
/// Sends everything read from stdin, NUL-terminated, to the Unix socket at
/// `socket`. It then copies the socket's reply to stdout. Any I/O failure is
/// reported on stderr and the process exits with status 1.
#[cfg(unix)]
pub fn main(socket: &str) {
    use std::io::{self, Read, Write};
    use std::os::unix::net::UnixStream;
    use std::process::exit;

    let mut stream = match UnixStream::connect(socket) {
        Ok(stream) => stream,
        Err(err) => {
            eprintln!("Error connecting to socket {}: {}", socket, err);
            exit(1);
        }
    };

    let mut buffer = Vec::new();
    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
        eprintln!("Error reading from stdin: {}", err);
        exit(1);
    }

    // The listener reads the prompt up to the first NUL, so ensure one is present.
    if buffer.last() != Some(&b'\0') {
        buffer.push(b'\0');
    }

    if let Err(err) = stream.write_all(&buffer) {
        eprintln!("Error writing to socket: {}", err);
        exit(1);
    }

    let mut response = Vec::new();
    if let Err(err) = stream.read_to_end(&mut response) {
        eprintln!("Error reading from socket: {}", err);
        exit(1);
    }

    // The reply usually has no trailing newline, so it would otherwise stay in
    // stdout's line buffer. Flush explicitly so a failed write is reported.
    let mut stdout = io::stdout().lock();
    let written = stdout.write_all(&response).and_then(|()| stdout.flush());
    if let Err(err) = written {
        eprintln!("Error writing to stdout: {}", err);
        exit(1);
    }
}

/// Askpass netcat mode is not implemented on non-Unix targets; this does nothing.
#[cfg(not(unix))]
pub fn main(_socket: &str) {}
209
/// Placeholder askpass session for non-Unix targets.
///
/// No script is created, and `run` only ever reports a timeout.
#[cfg(not(unix))]
pub struct AskPassSession {
    path: PathBuf,
}

#[cfg(not(unix))]
impl AskPassSession {
    /// How long [`Self::run`] waits before reporting [`AskPassResult::Timedout`].
    const RUN_TIMEOUT: Duration = Duration::from_secs(20);

    /// Creates a session with an empty script path; the delegate is dropped unused.
    pub async fn new(_: &BackgroundExecutor, _: AskPassDelegate) -> anyhow::Result<Self> {
        Ok(Self {
            path: PathBuf::new(),
        })
    }

    /// Returns the script path, which is always empty on this platform.
    pub fn script_path(&self) -> &Path {
        &self.path
    }

    /// Waits for [`Self::RUN_TIMEOUT`] and then reports [`AskPassResult::Timedout`].
    pub async fn run(&mut self) -> AskPassResult {
        // Awaited directly: fusing is only needed when polled inside `select!`.
        smol::Timer::after(Self::RUN_TIMEOUT).await;
        AskPassResult::Timedout
    }
}