1use std::{ffi::OsStr, time::Duration};
2
3use anyhow::{Context as _, Result};
4use futures::channel::{mpsc, oneshot};
5use futures::{
6 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
7 select_biased,
8};
9use gpui::{AsyncApp, BackgroundExecutor, Task};
10use smol::fs;
11use util::ResultExt as _;
12
/// Outcome of waiting on an askpass session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// A password request could not be answered after the helper connected
    /// (presumably because the user dismissed the prompt).
    CancelledByUser,
    /// No askpass helper connected before the connection timeout elapsed.
    Timedout,
}
18
/// Routes askpass password requests to a prompt callback.
///
/// Each request is a prompt string paired with a one-shot sender through which the
/// answer is returned.
pub struct AskPassDelegate {
    /// Queue of pending `(prompt, reply sender)` requests.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
    /// Task that drains the request queue; kept alive for the lifetime of the delegate.
    _task: Task<()>,
}
23
24impl AskPassDelegate {
25 pub fn new(
26 cx: &mut AsyncApp,
27 password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
28 ) -> Self {
29 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
30 let task = cx.spawn(async move |cx: &mut AsyncApp| {
31 while let Some((prompt, channel)) = rx.next().await {
32 password_prompt(prompt, channel, cx);
33 }
34 });
35 Self { tx, _task: task }
36 }
37
38 pub async fn ask_password(&mut self, prompt: String) -> Result<String> {
39 let (tx, rx) = oneshot::channel();
40 self.tx.send((prompt, tx)).await?;
41 Ok(rx.await?)
42 }
43
44 pub fn new_always_failing() -> Self {
45 let (tx, _rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
46 Self {
47 tx,
48 _task: Task::ready(()),
49 }
50 }
51}
52
/// An active askpass endpoint.
///
/// Consists of a listening Unix socket plus generated helper script(s) that external programs
/// can run to request a password. Requests are answered through an [`AskPassDelegate`].
pub struct AskPassSession {
    /// Path of the generated askpass helper script.
    #[cfg(not(target_os = "windows"))]
    script_path: std::path::PathBuf,
    /// Path of the generated gpg wrapper script.
    #[cfg(not(target_os = "windows"))]
    gpg_script_path: std::path::PathBuf,
    /// Command line that runs the PowerShell askpass helper script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
    /// Most recent password handed out, shared with the listener task.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<parking_lot::Mutex<String>>,
    /// Socket listener task; it owns the temporary directory containing the socket and scripts.
    _askpass_task: Task<()>,
    /// Signalled when the first helper connection is accepted; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Signalled when a password request fails; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
66
/// File name of the generated askpass helper script (POSIX shell).
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
/// File name of the generated askpass helper script (PowerShell on Windows).
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";

/// File name of the generated gpg wrapper script.
#[cfg(not(target_os = "windows"))]
const GPG_SCRIPT_NAME: &str = "gpg.sh";
74
75impl AskPassSession {
76 /// This will create a new AskPassSession.
77 /// You must retain this session until the master process exits.
78 #[must_use]
79 pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
80 use net::async_net::UnixListener;
81 use util::fs::make_file_executable;
82
83 #[cfg(target_os = "windows")]
84 let secret = std::sync::Arc::new(parking_lot::Mutex::new(String::new()));
85 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
86 let askpass_socket = temp_dir.path().join("askpass.sock");
87 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
88 #[cfg(not(target_os = "windows"))]
89 let gpg_script_path = temp_dir.path().join(GPG_SCRIPT_NAME);
90 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
91 let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
92 #[cfg(not(target_os = "windows"))]
93 let zed_path = util::get_shell_safe_zed_path()?;
94 #[cfg(target_os = "windows")]
95 let zed_path = std::env::current_exe()
96 .context("finding current executable path for use in askpass")?;
97
98 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
99 let mut kill_tx = Some(askpass_kill_master_tx);
100
101 #[cfg(target_os = "windows")]
102 let askpass_secret = secret.clone();
103 let askpass_task = executor.spawn(async move {
104 let mut askpass_opened_tx = Some(askpass_opened_tx);
105
106 while let Ok((mut stream, _)) = listener.accept().await {
107 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
108 askpass_opened_tx.send(()).ok();
109 }
110 let mut buffer = Vec::new();
111 let mut reader = BufReader::new(&mut stream);
112 if reader.read_until(b'\0', &mut buffer).await.is_err() {
113 buffer.clear();
114 }
115 let prompt = String::from_utf8_lossy(&buffer);
116 if let Some(password) = delegate
117 .ask_password(prompt.to_string())
118 .await
119 .context("getting askpass password")
120 .log_err()
121 {
122 stream.write_all(password.as_bytes()).await.log_err();
123 #[cfg(target_os = "windows")]
124 {
125 *askpass_secret.lock() = password;
126 }
127 } else {
128 if let Some(kill_tx) = kill_tx.take() {
129 kill_tx.send(()).log_err();
130 }
131 // note: we expect the caller to drop this task when it's done.
132 // We need to keep the stream open until the caller is done to avoid
133 // spurious errors from ssh.
134 std::future::pending::<()>().await;
135 drop(stream);
136 }
137 }
138 drop(temp_dir)
139 });
140
141 // Create an askpass script that communicates back to this process.
142 let askpass_script = generate_askpass_script(&zed_path, &askpass_socket);
143 fs::write(&askpass_script_path, askpass_script)
144 .await
145 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
146 make_file_executable(&askpass_script_path).await?;
147 #[cfg(target_os = "windows")]
148 let askpass_helper = format!(
149 "powershell.exe -ExecutionPolicy Bypass -File {}",
150 askpass_script_path.display()
151 );
152
153 #[cfg(not(target_os = "windows"))]
154 {
155 let gpg_script = generate_gpg_script();
156 fs::write(&gpg_script_path, gpg_script)
157 .await
158 .with_context(|| format!("creating gpg wrapper script at {gpg_script_path:?}"))?;
159 make_file_executable(&gpg_script_path).await?;
160 }
161
162 Ok(Self {
163 #[cfg(not(target_os = "windows"))]
164 script_path: askpass_script_path,
165 #[cfg(not(target_os = "windows"))]
166 gpg_script_path,
167
168 #[cfg(target_os = "windows")]
169 secret,
170 #[cfg(target_os = "windows")]
171 askpass_helper,
172
173 _askpass_task: askpass_task,
174 askpass_kill_master_rx: Some(askpass_kill_master_rx),
175 askpass_opened_rx: Some(askpass_opened_rx),
176 })
177 }
178
179 #[cfg(not(target_os = "windows"))]
180 pub fn script_path(&self) -> impl AsRef<OsStr> {
181 &self.script_path
182 }
183
184 #[cfg(target_os = "windows")]
185 pub fn script_path(&self) -> impl AsRef<OsStr> {
186 &self.askpass_helper
187 }
188
189 #[cfg(not(target_os = "windows"))]
190 pub fn gpg_script_path(&self) -> Option<impl AsRef<OsStr>> {
191 Some(&self.gpg_script_path)
192 }
193
194 #[cfg(target_os = "windows")]
195 pub fn gpg_script_path(&self) -> Option<impl AsRef<OsStr>> {
196 // TODO implement wrapping GPG on Windows. This is more difficult than on Unix
197 // because we can't use --passphrase-fd with a nonstandard FD, and both --passphrase
198 // and --passphrase-file are insecure.
199 None::<std::path::PathBuf>
200 }
201
202 // This will run the askpass task forever, resolving as many authentication requests as needed.
203 // The caller is responsible for examining the result of their own commands and cancelling this
204 // future when this is no longer needed. Note that this can only be called once, but due to the
205 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
206 pub async fn run(&mut self) -> AskPassResult {
207 // This is the default timeout setting used by VSCode.
208 let connection_timeout = Duration::from_secs(17);
209 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
210 let askpass_kill_master_rx = self
211 .askpass_kill_master_rx
212 .take()
213 .expect("Only call run once");
214
215 select_biased! {
216 _ = askpass_opened_rx.fuse() => {
217 // Note: this await can only resolve after we are dropped.
218 askpass_kill_master_rx.await.ok();
219 return AskPassResult::CancelledByUser
220 }
221
222 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
223 return AskPassResult::Timedout
224 }
225 }
226 }
227
228 /// This will return the password that was last set by the askpass script.
229 #[cfg(target_os = "windows")]
230 pub fn get_password(&self) -> String {
231 self.secret.lock().clone()
232 }
233}
234
235/// The main function for when Zed is running in netcat mode for use in askpass.
236/// Called from both the remote server binary and the zed binary in their respective main functions.
237pub fn main(socket: &str) {
238 use net::UnixStream;
239 use std::io::{self, Read, Write};
240 use std::process::exit;
241
242 let mut stream = match UnixStream::connect(socket) {
243 Ok(stream) => stream,
244 Err(err) => {
245 eprintln!("Error connecting to socket {}: {}", socket, err);
246 exit(1);
247 }
248 };
249
250 let mut buffer = Vec::new();
251 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
252 eprintln!("Error reading from stdin: {}", err);
253 exit(1);
254 }
255
256 #[cfg(target_os = "windows")]
257 while buffer.last().map_or(false, |&b| b == b'\n' || b == b'\r') {
258 buffer.pop();
259 }
260 if buffer.last() != Some(&b'\0') {
261 buffer.push(b'\0');
262 }
263
264 if let Err(err) = stream.write_all(&buffer) {
265 eprintln!("Error writing to socket: {}", err);
266 exit(1);
267 }
268
269 let mut response = Vec::new();
270 if let Err(err) = stream.read_to_end(&mut response) {
271 eprintln!("Error reading from socket: {}", err);
272 exit(1);
273 }
274
275 if let Err(err) = io::stdout().write_all(&response) {
276 eprintln!("Error writing to stdout: {}", err);
277 exit(1);
278 }
279}
280
/// Builds the POSIX `sh` askpass helper used outside Windows.
///
/// The script prints each of its arguments followed by a NUL byte and pipes them into
/// `zed_path --askpass=<askpass_socket>`, discarding that command's stderr.
///
/// `zed_path` is inserted verbatim and must already be safe to embed in a shell command. The
/// socket path is single-quoted here, so a temporary directory whose path contains spaces or
/// shell metacharacters still works.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_path: &str, askpass_socket: &std::path::Path) -> String {
    // POSIX single-quoting: everything inside '...' is literal. An embedded `'` becomes `'\''`
    // (close quote, escaped quote, reopen quote).
    let quoted_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', r"'\''")
    );
    format!(
        "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null\n",
        shebang = "#!/bin/sh",
        print_args = "printf '%s\\0' \"$@\"",
        zed_exe = zed_path,
        askpass_socket = quoted_socket,
    )
}
292
/// Builds the PowerShell askpass helper used on Windows.
///
/// The script joins its arguments with NUL characters and pipes the result into
/// `zed_path --askpass=<askpass_socket>`, discarding that command's stderr.
/// `$ErrorActionPreference = 'Stop'` makes the script abort on the first error.
// NOTE(review): unlike the executable path, the socket path is not quoted; a temp directory
// containing spaces would presumably break this command line — confirm on Windows.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_path: &std::path::Path, askpass_socket: &std::path::Path) -> String {
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{zed_exe}" --askpass={askpass_socket} 2> $null
        "#,
        zed_exe = zed_path.display(),
        askpass_socket = askpass_socket.display(),
    )
}
305
/// Builds the `sh` gpg wrapper script (presumably installed as git's `gpg.program` by the caller).
///
/// The wrapper:
/// - resolves the real gpg binary from `git config gpg.program`, defaulting to `gpg`;
/// - asks the `$GIT_ASKPASS` program for the passphrase;
/// - `exec`s gpg in batch / loopback-pinentry mode, passing the passphrase on file descriptor 3
///   through a here-document.
///
/// The script is assembled line by line. This guarantees the shebang starts at byte 0 and the
/// here-document terminator sits at column 0, regardless of how this source file is indented.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_gpg_script() -> String {
    const LINES: &[&str] = &[
        "#!/bin/sh",
        "set -eu",
        "",
        "unset GIT_CONFIG_PARAMETERS",
        "GPG_PROGRAM=$(git config gpg.program || echo 'gpg')",
        "PROMPT=\"Enter passphrase to unlock GPG key:\"",
        // Quoted so an askpass path containing spaces is not word-split; git runs
        // GIT_ASKPASS as a single program path.
        "PASSPHRASE=$(\"${GIT_ASKPASS}\" \"${PROMPT}\")",
        "",
        "exec \"${GPG_PROGRAM}\" --batch --no-tty --yes --passphrase-fd 3 --pinentry-mode loopback \"$@\" 3<<EOF",
        "${PASSPHRASE}",
        "EOF",
    ];

    let mut script = LINES.join("\n");
    script.push('\n');
    script
}