1mod encrypted_password;
2
3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
4
5use net::async_net::UnixListener;
6use smol::lock::Mutex;
7use util::fs::make_file_executable;
8
9use std::ffi::OsStr;
10use std::ops::ControlFlow;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use anyhow::{Context as _, Result};
16use futures::channel::{mpsc, oneshot};
17use futures::{
18 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
19 select_biased,
20};
21use gpui::{AsyncApp, BackgroundExecutor, Task};
22use smol::fs;
23use util::{ResultExt as _, debug_panic, maybe, paths::PathExt, shell::ShellKind};
24
/// Path to the program used for askpass.
///
/// The generated askpass script pipes its arguments into this program, invoked as
/// `<program> --askpass=<socket>`. It can be set once via [`set_askpass_program`];
/// if it is still unset when [`PasswordProxy::new`] runs, it is initialized to the
/// current executable.
///
/// On Unix and remote servers, this defaults to the current executable.
/// On Windows, this is set to the CLI variant of zed.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
30
/// Outcome of [`AskPassSession::run`].
///
/// `Debug` is derived so the result can be logged and compared in assertions; the
/// enum is a plain fieldless value, so it is also `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// An askpass request arrived, and a prompt was later dismissed without a password.
    CancelledByUser,
    /// No askpass request arrived before `run`'s connection timeout elapsed.
    Timedout,
}
36
/// Routes password prompts to a prompt callback.
///
/// Prompts submitted through [`AskPassDelegate::ask_password`] are queued and handed,
/// in order, to the `password_prompt` callback passed to [`AskPassDelegate::new`].
pub struct AskPassDelegate {
    // Queue of prompts, each paired with the sender on which its answer is returned.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    // Spawns the futures returned by `ask_password`.
    executor: BackgroundExecutor,
    // Drains the queue into the prompt callback; held so it lives as long as the delegate.
    _task: Task<()>,
}
42
43impl AskPassDelegate {
44 pub fn new(
45 cx: &mut AsyncApp,
46 password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
47 + Send
48 + Sync
49 + 'static,
50 ) -> Self {
51 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
52 let task = cx.spawn(async move |cx: &mut AsyncApp| {
53 while let Some((prompt, channel)) = rx.next().await {
54 password_prompt(prompt, channel, cx);
55 }
56 });
57 Self {
58 tx,
59 _task: task,
60 executor: cx.background_executor().clone(),
61 }
62 }
63
64 pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
65 let mut this_tx = self.tx.clone();
66 self.executor.spawn(async move {
67 let (tx, rx) = oneshot::channel();
68 this_tx.send((prompt, tx)).await.ok()?;
69 rx.await.ok()
70 })
71 }
72}
73
/// Askpass state for one master process, e.g. an ssh connection.
///
/// Create it with [`AskPassSession::new`], point the master process at
/// [`AskPassSession::script_path`], and drive [`AskPassSession::run`].
pub struct AskPassSession {
    // Last password handed to the askpass script; read back via `get_password`.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<std::sync::Mutex<Option<EncryptedPassword>>>,
    // Serves the askpass script's requests.
    askpass_task: PasswordProxy,
    // Resolves when the first askpass request arrives. Taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    // Resolves when a prompt is dismissed without a password. Taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
    // Provides the connection-timeout timer used by `run`.
    executor: BackgroundExecutor,
}
82
/// File name of the generated askpass script inside the proxy's temporary directory.
///
/// Windows gets a PowerShell script, which is launched through `powershell.exe`.
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
/// File name of the generated askpass script inside the proxy's temporary directory.
///
/// Non-Windows platforms get a POSIX `sh` script.
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
88
impl AskPassSession {
    /// This will create a new AskPassSession.
    /// You must retain this session until the master process exits.
    ///
    /// The session owns a [`PasswordProxy`] whose askpass script forwards every prompt
    /// to `delegate`. The first prompt tells [`AskPassSession::run`] that the script has
    /// been invoked. A prompt that `delegate` answers without a password tells `run`
    /// that the user cancelled.
    ///
    /// # Errors
    ///
    /// Returns an error if the [`PasswordProxy`] cannot be created.
    #[must_use]
    pub async fn new(executor: BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
        // Last password handed to the askpass script; exposed via `get_password`.
        #[cfg(target_os = "windows")]
        let secret = std::sync::Arc::new(std::sync::Mutex::new(None));

        // Signalled by the first askpass request, so `run` stops racing its timeout.
        let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();

        // Shared by every request; only the first one finds `Some` and sends.
        let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));

        // Signalled when a prompt is dismissed without a password.
        let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
        let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));

        // Called by the proxy once per askpass request.
        let get_password = {
            let executor = executor.clone();

            #[cfg(target_os = "windows")]
            let askpass_secret = secret.clone();
            move |prompt| {
                let prompt = delegate.ask_password(prompt);
                let kill_tx = kill_tx.clone();
                let askpass_opened_tx = askpass_opened_tx.clone();
                #[cfg(target_os = "windows")]
                let askpass_secret = askpass_secret.clone();
                executor.spawn(async move {
                    if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
                        askpass_opened_tx.send(()).ok();
                    }
                    if let Some(password) = prompt.await {
                        #[cfg(target_os = "windows")]
                        {
                            askpass_secret.lock().unwrap().replace(password.clone());
                        }
                        ControlFlow::Continue(Ok(password))
                    } else {
                        // No password: report the cancellation to `run`. `Break` makes the
                        // proxy hold this connection open instead of replying.
                        if let Some(kill_tx) = kill_tx.lock().await.take() {
                            kill_tx.send(()).log_err();
                        }
                        ControlFlow::Break(())
                    }
                })
            }
        };
        let askpass_task = PasswordProxy::new(Box::new(get_password), executor.clone()).await?;

        Ok(Self {
            #[cfg(target_os = "windows")]
            secret,

            askpass_task,
            askpass_kill_master_rx: Some(askpass_kill_master_rx),
            askpass_opened_rx: Some(askpass_opened_rx),
            executor,
        })
    }

    // This will run the askpass task forever, resolving as many authentication requests as needed.
    // The caller is responsible for examining the result of their own commands and cancelling this
    // future when this is no longer needed. Note that this can only be called once, but due to the
    // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
    //
    // Returns `Timedout` if no askpass request arrives within `connection_timeout`.
    // Once a request has arrived, it returns `CancelledByUser` only after a prompt is dismissed
    // without a password (or the cancellation sender is dropped). Otherwise it stays pending.
    //
    // Panics if called more than once.
    pub async fn run(&mut self) -> AskPassResult {
        // This is the default timeout setting used by VSCode.
        let connection_timeout = Duration::from_secs(17);
        let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
        let askpass_kill_master_rx = self
            .askpass_kill_master_rx
            .take()
            .expect("Only call run once");
        let executor = self.executor.clone();

        // Biased: if both are ready, an arrived request wins over the timeout.
        select_biased! {
            _ = askpass_opened_rx.fuse() => {
                // Resolves when a prompt is dismissed without a password, or if the sender
                // is dropped. While prompts keep being answered, this stays pending, so the
                // caller must cancel `run` once the master process is done.
                askpass_kill_master_rx.await.ok();
                AskPassResult::CancelledByUser
            }

            _ = futures::FutureExt::fuse(executor.timer(connection_timeout)) => {
                AskPassResult::Timedout
            }
        }
    }

    /// This will return the password that was last set by the askpass script.
    ///
    /// Returns `None` if no password has been provided yet or the lock is poisoned.
    #[cfg(target_os = "windows")]
    pub fn get_password(&self) -> Option<EncryptedPassword> {
        self.secret.lock().ok()?.clone()
    }

    /// Returns what the master process should be given as its askpass program.
    ///
    /// This is the script path, or on Windows a `powershell.exe` command line that runs
    /// the script.
    pub fn script_path(&self) -> impl AsRef<OsStr> {
        self.askpass_task.script_path()
    }
}
184
/// Answers askpass-script requests arriving on a Unix socket in a temporary directory.
///
/// Created by [`PasswordProxy::new`]; use [`PasswordProxy::script_path`] to obtain the
/// program the master process should run as its askpass helper.
pub struct PasswordProxy {
    // Accept loop; it also owns the temporary directory holding the socket and script.
    _task: Task<()>,
    // Path of the generated `askpass.sh`.
    #[cfg(not(target_os = "windows"))]
    askpass_script_path: std::path::PathBuf,
    // `powershell.exe` command line that runs the generated `askpass.ps1`.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
}
192
193impl PasswordProxy {
194 pub async fn new(
195 mut get_password: Box<
196 dyn FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
197 + 'static
198 + Send
199 + Sync,
200 >,
201 executor: BackgroundExecutor,
202 ) -> Result<Self> {
203 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
204 let askpass_socket = temp_dir.path().join("askpass.sock");
205 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
206 let current_exec =
207 std::env::current_exe().context("Failed to determine current zed executable path.")?;
208
209 // TODO: inferred from the use of powershell.exe in askpass_helper_script
210 let shell_kind = if cfg!(windows) {
211 ShellKind::PowerShell
212 } else {
213 ShellKind::Posix
214 };
215 let askpass_program = ASKPASS_PROGRAM.get_or_init(|| current_exec);
216 // Create an askpass script that communicates back to this process.
217 let askpass_script = generate_askpass_script(shell_kind, askpass_program, &askpass_socket)?;
218 let _task = executor.spawn(async move {
219 maybe!(async move {
220 let listener =
221 UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
222
223 while let Ok((mut stream, _)) = listener.accept().await {
224 let mut buffer = Vec::new();
225 let mut reader = BufReader::new(&mut stream);
226 if reader.read_until(b'\0', &mut buffer).await.is_err() {
227 buffer.clear();
228 }
229 let prompt = String::from_utf8_lossy(&buffer).into_owned();
230 let password = get_password(prompt).await;
231 match password {
232 ControlFlow::Continue(password) => {
233 if let Ok(password) = password
234 && let Ok(decrypted) =
235 password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
236 {
237 stream.write_all(decrypted.as_bytes()).await.log_err();
238 }
239 }
240 ControlFlow::Break(()) => {
241 // note: we expect the caller to drop this task when it's done.
242 // We need to keep the stream open until the caller is done to avoid
243 // spurious errors from ssh.
244 std::future::pending::<()>().await;
245 drop(stream);
246 }
247 }
248 }
249 drop(temp_dir);
250 Result::<_, anyhow::Error>::Ok(())
251 })
252 .await
253 .log_err();
254 });
255
256 fs::write(&askpass_script_path, askpass_script)
257 .await
258 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
259 make_file_executable(&askpass_script_path)
260 .await
261 .with_context(|| {
262 format!("marking askpass script executable at {askpass_script_path:?}")
263 })?;
264 // todo(shell): There might be no powershell on the system
265 #[cfg(target_os = "windows")]
266 let askpass_helper = format!(
267 "powershell.exe -ExecutionPolicy Bypass -File \"{}\"",
268 askpass_script_path.display()
269 );
270
271 Ok(Self {
272 _task,
273 #[cfg(not(target_os = "windows"))]
274 askpass_script_path,
275 #[cfg(target_os = "windows")]
276 askpass_helper,
277 })
278 }
279
280 pub fn script_path(&self) -> impl AsRef<OsStr> {
281 #[cfg(not(target_os = "windows"))]
282 {
283 &self.askpass_script_path
284 }
285 #[cfg(target_os = "windows")]
286 {
287 &self.askpass_helper
288 }
289 }
290}
291/// The main function for when Zed is running in netcat mode for use in askpass.
292/// Called from both the remote server binary and the zed binary in their respective main functions.
293pub fn main(socket: &str) {
294 use net::UnixStream;
295 use std::io::{self, Read, Write};
296 use std::process::exit;
297
298 let mut stream = match UnixStream::connect(socket) {
299 Ok(stream) => stream,
300 Err(err) => {
301 eprintln!("Error connecting to socket {}: {}", socket, err);
302 exit(1);
303 }
304 };
305
306 let mut buffer = Vec::new();
307 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
308 eprintln!("Error reading from stdin: {}", err);
309 exit(1);
310 }
311
312 #[cfg(target_os = "windows")]
313 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
314 buffer.pop();
315 }
316 if buffer.last() != Some(&b'\0') {
317 buffer.push(b'\0');
318 }
319
320 if let Err(err) = stream.write_all(&buffer) {
321 eprintln!("Error writing to socket: {}", err);
322 exit(1);
323 }
324
325 let mut response = Vec::new();
326 if let Err(err) = stream.read_to_end(&mut response) {
327 eprintln!("Error reading from socket: {}", err);
328 exit(1);
329 }
330
331 if let Err(err) = io::stdout().write_all(&response) {
332 eprintln!("Error writing to stdout: {}", err);
333 exit(1);
334 }
335}
336
337pub fn set_askpass_program(path: std::path::PathBuf) {
338 if ASKPASS_PROGRAM.set(path).is_err() {
339 debug_panic!("askpass program has already been set");
340 }
341}
342
343#[inline]
344#[cfg(not(target_os = "windows"))]
345fn generate_askpass_script(
346 shell_kind: ShellKind,
347 askpass_program: &std::path::Path,
348 askpass_socket: &std::path::Path,
349) -> Result<String> {
350 let askpass_program = shell_kind.prepend_command_prefix(
351 askpass_program
352 .to_str()
353 .context("Askpass program is on a non-utf8 path")?,
354 );
355 let askpass_program = shell_kind
356 .try_quote_prefix_aware(&askpass_program)
357 .context("Failed to shell-escape Askpass program path")?;
358 let askpass_socket = askpass_socket
359 .try_shell_safe(shell_kind)
360 .context("Failed to shell-escape Askpass socket path")?;
361 let print_args = "printf '%s\\0' \"$@\"";
362 let shebang = "#!/bin/sh";
363 Ok(format!(
364 "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
365 ))
366}
367
/// Renders the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result into
/// `<askpass_program> --askpass=<askpass_socket>`, discarding the program's stderr.
/// `$ErrorActionPreference = 'Stop'` makes PowerShell stop at the first error.
///
/// # Errors
///
/// Fails if `askpass_program` is not valid UTF-8, or if either path cannot be
/// shell-escaped.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(
    shell_kind: ShellKind,
    askpass_program: &std::path::Path,
    askpass_socket: &std::path::Path,
) -> Result<String> {
    // NOTE(review): `prepend_command_prefix` is assumed to add whatever token this shell
    // needs to invoke a quoted path; confirm in `ShellKind`.
    let askpass_program = shell_kind.prepend_command_prefix(
        askpass_program
            .to_str()
            .context("Askpass program is on a non-utf8 path")?,
    );
    let askpass_program = shell_kind
        .try_quote_prefix_aware(&askpass_program)
        .context("Failed to shell-escape Askpass program path")?;
    let askpass_socket = askpass_socket
        .try_shell_safe(shell_kind)
        .context("Failed to shell-escape Askpass socket path")?;
    // Everything inside the raw string, including its whitespace, is emitted verbatim.
    Ok(format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | {askpass_program} --askpass={askpass_socket} 2> $null
        "#,
    ))
}