1mod encrypted_password;
2
3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
4
5use net::async_net::UnixListener;
6use smol::lock::Mutex;
7use util::fs::make_file_executable;
8
9use std::ffi::OsStr;
10use std::ops::ControlFlow;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use anyhow::{Context as _, Result};
16use futures::channel::{mpsc, oneshot};
17use futures::{
18 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
19 select_biased,
20};
21use gpui::{AsyncApp, BackgroundExecutor, Task};
22use smol::fs;
23use util::{
24 ResultExt as _, debug_panic, maybe,
25 paths::PathExt,
26 shell::{PosixShell, ShellKind},
27};
28
/// Path to the program used for askpass
///
/// On Unix and remote servers, this defaults to the current executable
/// On Windows, this is set to the CLI variant of zed
///
/// The value is written at most once: either explicitly through
/// `set_askpass_program`, or lazily with the current executable the first time
/// `PasswordProxy::new` needs it.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
34
/// Outcome reported by `AskPassSession::run` when it resolves.
///
/// `Debug` is derived so values can be logged and used with `assert_eq!`; the
/// enum is field-less, so `Clone`/`Copy` are free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// An askpass prompt was opened and the session was subsequently told to
    /// stop (for example because no password was supplied).
    CancelledByUser,
    /// No askpass prompt was opened before the connection timeout elapsed.
    Timedout,
}
40
/// Handle used to request passwords from a caller-supplied prompt callback.
///
/// Requests are sent over `tx` to the task stored in `_task`, which is spawned
/// by `AskPassDelegate::new` and invokes the callback once per request.
pub struct AskPassDelegate {
    /// Sends `(prompt, reply channel)` pairs to the prompt task.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    /// Executor on which `ask_password` spawns its request futures.
    executor: BackgroundExecutor,
    /// Task that receives requests and calls the prompt callback; held for as
    /// long as the delegate exists.
    _task: Task<()>,
}
46
47impl AskPassDelegate {
48 pub fn new(
49 cx: &mut AsyncApp,
50 password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
51 + Send
52 + Sync
53 + 'static,
54 ) -> Self {
55 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
56 let task = cx.spawn(async move |cx: &mut AsyncApp| {
57 while let Some((prompt, channel)) = rx.next().await {
58 password_prompt(prompt, channel, cx);
59 }
60 });
61 Self {
62 tx,
63 _task: task,
64 executor: cx.background_executor().clone(),
65 }
66 }
67
68 pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
69 let mut this_tx = self.tx.clone();
70 self.executor.spawn(async move {
71 let (tx, rx) = oneshot::channel();
72 this_tx.send((prompt, tx)).await.ok()?;
73 rx.await.ok()
74 })
75 }
76}
77
/// A running askpass setup: the password proxy plus the channels that
/// `AskPassSession::run` waits on.
pub struct AskPassSession {
    /// Most recent password returned by the delegate; exposed through
    /// `get_password` (Windows only).
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<std::sync::Mutex<Option<EncryptedPassword>>>,
    /// Proxy that owns the askpass script and the task answering its requests.
    askpass_task: PasswordProxy,
    /// Signalled the first time a password request arrives; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Signalled when the delegate yields no password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
    /// Executor used by `run` to create the connection-timeout timer.
    executor: BackgroundExecutor,
}
86
/// File name of the generated askpass script: a PowerShell script on Windows,
/// a shell script everywhere else.
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
92
93impl AskPassSession {
94 /// This will create a new AskPassSession.
95 /// You must retain this session until the master process exits.
96 #[must_use]
97 pub async fn new(executor: BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
98 #[cfg(target_os = "windows")]
99 let secret = std::sync::Arc::new(std::sync::Mutex::new(None));
100
101 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
102
103 let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
104
105 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
106 let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
107
108 let get_password = {
109 let executor = executor.clone();
110
111 #[cfg(target_os = "windows")]
112 let askpass_secret = secret.clone();
113 move |prompt| {
114 let prompt = delegate.ask_password(prompt);
115 let kill_tx = kill_tx.clone();
116 let askpass_opened_tx = askpass_opened_tx.clone();
117 #[cfg(target_os = "windows")]
118 let askpass_secret = askpass_secret.clone();
119 executor.spawn(async move {
120 if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
121 askpass_opened_tx.send(()).ok();
122 }
123 if let Some(password) = prompt.await {
124 #[cfg(target_os = "windows")]
125 {
126 askpass_secret.lock().unwrap().replace(password.clone());
127 }
128 ControlFlow::Continue(Ok(password))
129 } else {
130 if let Some(kill_tx) = kill_tx.lock().await.take() {
131 kill_tx.send(()).log_err();
132 }
133 ControlFlow::Break(())
134 }
135 })
136 }
137 };
138 let askpass_task = PasswordProxy::new(get_password, executor.clone()).await?;
139
140 Ok(Self {
141 #[cfg(target_os = "windows")]
142 secret,
143
144 askpass_task,
145 askpass_kill_master_rx: Some(askpass_kill_master_rx),
146 askpass_opened_rx: Some(askpass_opened_rx),
147 executor,
148 })
149 }
150
151 // This will run the askpass task forever, resolving as many authentication requests as needed.
152 // The caller is responsible for examining the result of their own commands and cancelling this
153 // future when this is no longer needed. Note that this can only be called once, but due to the
154 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
155 pub async fn run(&mut self) -> AskPassResult {
156 // This is the default timeout setting used by VSCode.
157 let connection_timeout = Duration::from_secs(17);
158 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
159 let askpass_kill_master_rx = self
160 .askpass_kill_master_rx
161 .take()
162 .expect("Only call run once");
163 let executor = self.executor.clone();
164
165 select_biased! {
166 _ = askpass_opened_rx.fuse() => {
167 // Note: this await can only resolve after we are dropped.
168 askpass_kill_master_rx.await.ok();
169 AskPassResult::CancelledByUser
170 }
171
172 _ = futures::FutureExt::fuse(executor.timer(connection_timeout)) => {
173 AskPassResult::Timedout
174 }
175 }
176 }
177
178 /// This will return the password that was last set by the askpass script.
179 #[cfg(target_os = "windows")]
180 pub fn get_password(&self) -> Option<EncryptedPassword> {
181 self.secret.lock().ok()?.clone()
182 }
183
184 pub fn script_path(&self) -> impl AsRef<OsStr> {
185 self.askpass_task.script_path()
186 }
187}
188
/// Owns the generated askpass script and the background task that answers the
/// requests that script sends over the askpass socket.
pub struct PasswordProxy {
    /// Task that accepts socket connections and replies with passwords.
    _task: Task<()>,
    /// Path of the generated askpass script.
    #[cfg(not(target_os = "windows"))]
    askpass_script_path: std::path::PathBuf,
    /// `powershell.exe` command line that runs the generated askpass script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
}
196
197impl PasswordProxy {
198 pub async fn new(
199 mut get_password: impl FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
200 + 'static
201 + Send
202 + Sync,
203 executor: BackgroundExecutor,
204 ) -> Result<Self> {
205 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
206 let askpass_socket = temp_dir.path().join("askpass.sock");
207 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
208 let current_exec =
209 std::env::current_exe().context("Failed to determine current zed executable path.")?;
210
211 // TODO: inferred from the use of powershell.exe in askpass_helper_script
212 let shell_kind = if cfg!(windows) {
213 ShellKind::PowerShell
214 } else {
215 // TODO: Consider using the user's actual shell instead of hardcoding "sh"
216 ShellKind::Posix(PosixShell::Sh)
217 };
218 let askpass_program = ASKPASS_PROGRAM.get_or_init(|| current_exec);
219 // Create an askpass script that communicates back to this process.
220 let askpass_script = generate_askpass_script(shell_kind, askpass_program, &askpass_socket)?;
221 let _task = executor.spawn(async move {
222 maybe!(async move {
223 let listener =
224 UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
225
226 while let Ok((mut stream, _)) = listener.accept().await {
227 let mut buffer = Vec::new();
228 let mut reader = BufReader::new(&mut stream);
229 if reader.read_until(b'\0', &mut buffer).await.is_err() {
230 buffer.clear();
231 }
232 let prompt = String::from_utf8_lossy(&buffer).into_owned();
233 let password = get_password(prompt).await;
234 match password {
235 ControlFlow::Continue(password) => {
236 if let Ok(password) = password
237 && let Ok(decrypted) =
238 password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
239 {
240 stream.write_all(decrypted.as_bytes()).await.log_err();
241 }
242 }
243 ControlFlow::Break(()) => {
244 // note: we expect the caller to drop this task when it's done.
245 // We need to keep the stream open until the caller is done to avoid
246 // spurious errors from ssh.
247 std::future::pending::<()>().await;
248 drop(stream);
249 }
250 }
251 }
252 drop(temp_dir);
253 Result::<_, anyhow::Error>::Ok(())
254 })
255 .await
256 .log_err();
257 });
258
259 fs::write(&askpass_script_path, askpass_script)
260 .await
261 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
262 make_file_executable(&askpass_script_path)
263 .await
264 .with_context(|| {
265 format!("marking askpass script executable at {askpass_script_path:?}")
266 })?;
267 // todo(shell): There might be no powershell on the system
268 #[cfg(target_os = "windows")]
269 let askpass_helper = format!(
270 "powershell.exe -ExecutionPolicy Bypass -File \"{}\"",
271 askpass_script_path.display()
272 );
273
274 Ok(Self {
275 _task,
276 #[cfg(not(target_os = "windows"))]
277 askpass_script_path,
278 #[cfg(target_os = "windows")]
279 askpass_helper,
280 })
281 }
282
283 pub fn script_path(&self) -> impl AsRef<OsStr> {
284 #[cfg(not(target_os = "windows"))]
285 {
286 &self.askpass_script_path
287 }
288 #[cfg(target_os = "windows")]
289 {
290 &self.askpass_helper
291 }
292 }
293}
294/// The main function for when Zed is running in netcat mode for use in askpass.
295/// Called from both the remote server binary and the zed binary in their respective main functions.
296pub fn main(socket: &str) {
297 use net::UnixStream;
298 use std::io::{self, Read, Write};
299 use std::process::exit;
300
301 let mut stream = match UnixStream::connect(socket) {
302 Ok(stream) => stream,
303 Err(err) => {
304 eprintln!("Error connecting to socket {}: {}", socket, err);
305 exit(1);
306 }
307 };
308
309 let mut buffer = Vec::new();
310 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
311 eprintln!("Error reading from stdin: {}", err);
312 exit(1);
313 }
314
315 #[cfg(target_os = "windows")]
316 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
317 buffer.pop();
318 }
319 if buffer.last() != Some(&b'\0') {
320 buffer.push(b'\0');
321 }
322
323 if let Err(err) = stream.write_all(&buffer) {
324 eprintln!("Error writing to socket: {}", err);
325 exit(1);
326 }
327
328 let mut response = Vec::new();
329 if let Err(err) = stream.read_to_end(&mut response) {
330 eprintln!("Error reading from socket: {}", err);
331 exit(1);
332 }
333
334 if let Err(err) = io::stdout().write_all(&response) {
335 eprintln!("Error writing to stdout: {}", err);
336 exit(1);
337 }
338}
339
340pub fn set_askpass_program(path: std::path::PathBuf) {
341 if ASKPASS_PROGRAM.set(path).is_err() {
342 debug_panic!("askpass program has already been set");
343 }
344}
345
346#[inline]
347#[cfg(not(target_os = "windows"))]
348fn generate_askpass_script(
349 shell_kind: ShellKind,
350 askpass_program: &std::path::Path,
351 askpass_socket: &std::path::Path,
352) -> Result<String> {
353 let askpass_program = shell_kind.prepend_command_prefix(
354 askpass_program
355 .to_str()
356 .context("Askpass program is on a non-utf8 path")?,
357 );
358 let askpass_program = shell_kind
359 .try_quote_prefix_aware(&askpass_program)
360 .context("Failed to shell-escape Askpass program path")?;
361 let askpass_socket = askpass_socket
362 .try_shell_safe(Some(&shell_kind))
363 .context("Failed to shell-escape Askpass socket path")?;
364 let print_args = "printf '%s\\0' \"$@\"";
365 let shebang = "#!/bin/sh";
366 Ok(format!(
367 "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
368 ))
369}
370
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result
/// into `askpass_program --askpass=<askpass_socket>`, discarding stderr.
///
/// Returns an error when `askpass_program` is not valid UTF-8, or when either
/// path cannot be escaped for `shell_kind`.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(
    shell_kind: ShellKind,
    askpass_program: &std::path::Path,
    askpass_socket: &std::path::Path,
) -> Result<String> {
    // Apply the shell's command prefix before quoting the program path.
    let askpass_program = shell_kind.prepend_command_prefix(
        askpass_program
            .to_str()
            .context("Askpass program is on a non-utf8 path")?,
    );
    let askpass_program = shell_kind
        .try_quote_prefix_aware(&askpass_program)
        .context("Failed to shell-escape Askpass program path")?;
    let askpass_socket = askpass_socket
        .try_shell_safe(Some(&shell_kind))
        .context("Failed to shell-escape Askpass socket path")?;
    Ok(format!(
        r#"
 $ErrorActionPreference = 'Stop';
 ($args -join [char]0) | {askpass_program} --askpass={askpass_socket} 2> $null
 "#,
    ))
}