1mod encrypted_password;
2
3pub use encrypted_password::{EncryptedPassword, IKnowWhatIAmDoingAndIHaveReadTheDocs};
4
5use net::async_net::UnixListener;
6use smol::lock::Mutex;
7use util::fs::make_file_executable;
8
9use std::ffi::OsStr;
10use std::ops::ControlFlow;
11use std::sync::Arc;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use anyhow::{Context as _, Result};
16use futures::channel::{mpsc, oneshot};
17use futures::{
18 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
19 select_biased,
20};
21use gpui::{AsyncApp, BackgroundExecutor, Task};
22use smol::fs;
23use util::{ResultExt as _, debug_panic, maybe, paths::PathExt};
24
/// Program that the generated askpass script pipes the prompt into.
///
/// Unless [`set_askpass_program`] is called first, creating the first `PasswordProxy`
/// initializes this to the current executable (the Unix and remote-server case). On Windows
/// it is expected to be set to the CLI variant of Zed.
static ASKPASS_PROGRAM: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
30
/// Why [`AskPassSession::run`] finished.
///
/// `Debug` is derived so the outcome can be logged and compared in assertions; the enum is
/// fieldless, so `Clone`/`Copy` cost nothing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// The askpass script was invoked, and afterwards a prompt got no password from the
    /// delegate (e.g. the user dismissed it), so the master process should be killed.
    CancelledByUser,
    /// The askpass script was not invoked within the connection timeout.
    Timedout,
}
36
/// Routes askpass prompts to a UI callback and hands back the entered password.
///
/// Built with [`AskPassDelegate::new`]; each call to [`AskPassDelegate::ask_password`]
/// queues one request for the callback.
pub struct AskPassDelegate {
    /// Queue of `(prompt, reply sender)` requests consumed by `_task`.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
    /// Executor on which `ask_password` spawns the task that awaits each reply.
    executor: BackgroundExecutor,
    /// Loop (spawned in `new`) that hands each queued request to the prompt callback.
    _task: Task<()>,
}
42
43impl AskPassDelegate {
44 pub fn new(
45 cx: &mut AsyncApp,
46 password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
47 + Send
48 + Sync
49 + 'static,
50 ) -> Self {
51 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
52 let task = cx.spawn(async move |cx: &mut AsyncApp| {
53 while let Some((prompt, channel)) = rx.next().await {
54 password_prompt(prompt, channel, cx);
55 }
56 });
57 Self {
58 tx,
59 _task: task,
60 executor: cx.background_executor().clone(),
61 }
62 }
63
64 pub fn ask_password(&mut self, prompt: String) -> Task<Option<EncryptedPassword>> {
65 let mut this_tx = self.tx.clone();
66 self.executor.spawn(async move {
67 let (tx, rx) = oneshot::channel();
68 this_tx.send((prompt, tx)).await.ok()?;
69 rx.await.ok()
70 })
71 }
72}
73
/// An askpass script plus the machinery that answers its prompts via an [`AskPassDelegate`].
///
/// Created with [`AskPassSession::new`] and must be kept alive until the master process exits.
pub struct AskPassSession {
    /// Last password handed to the askpass script (set in `new`'s prompt callback, read by
    /// `get_password`).
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
    /// Proxy owning the askpass socket/script and the task that serves it.
    askpass_task: PasswordProxy,
    /// Signalled when the first prompt arrives; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Signalled when a prompt gets no password from the delegate; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
81
/// File name of the generated askpass script: a PowerShell script on Windows...
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &'static str = "askpass.ps1";
/// ...and a POSIX `sh` script everywhere else.
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &'static str = "askpass.sh";
86
87impl AskPassSession {
88 /// This will create a new AskPassSession.
89 /// You must retain this session until the master process exits.
90 #[must_use]
91 pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
92 #[cfg(target_os = "windows")]
93 let secret = std::sync::Arc::new(OnceLock::new());
94 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
95
96 let askpass_opened_tx = Arc::new(Mutex::new(Some(askpass_opened_tx)));
97
98 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
99 let kill_tx = Arc::new(Mutex::new(Some(askpass_kill_master_tx)));
100
101 #[cfg(target_os = "windows")]
102 let askpass_secret = secret.clone();
103 let get_password = {
104 let executor = executor.clone();
105
106 move |prompt| {
107 let prompt = delegate.ask_password(prompt);
108 let kill_tx = kill_tx.clone();
109 let askpass_opened_tx = askpass_opened_tx.clone();
110 #[cfg(target_os = "windows")]
111 let askpass_secret = askpass_secret.clone();
112 executor.spawn(async move {
113 if let Some(askpass_opened_tx) = askpass_opened_tx.lock().await.take() {
114 askpass_opened_tx.send(()).ok();
115 }
116 if let Some(password) = prompt.await {
117 #[cfg(target_os = "windows")]
118 {
119 _ = askpass_secret.set(password.clone());
120 }
121 ControlFlow::Continue(Ok(password))
122 } else {
123 if let Some(kill_tx) = kill_tx.lock().await.take() {
124 kill_tx.send(()).log_err();
125 }
126 ControlFlow::Break(())
127 }
128 })
129 }
130 };
131 let askpass_task = PasswordProxy::new(get_password, executor.clone()).await?;
132
133 Ok(Self {
134 #[cfg(target_os = "windows")]
135 secret,
136
137 askpass_task,
138 askpass_kill_master_rx: Some(askpass_kill_master_rx),
139 askpass_opened_rx: Some(askpass_opened_rx),
140 })
141 }
142
143 // This will run the askpass task forever, resolving as many authentication requests as needed.
144 // The caller is responsible for examining the result of their own commands and cancelling this
145 // future when this is no longer needed. Note that this can only be called once, but due to the
146 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
147 pub async fn run(&mut self) -> AskPassResult {
148 // This is the default timeout setting used by VSCode.
149 let connection_timeout = Duration::from_secs(17);
150 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
151 let askpass_kill_master_rx = self
152 .askpass_kill_master_rx
153 .take()
154 .expect("Only call run once");
155
156 select_biased! {
157 _ = askpass_opened_rx.fuse() => {
158 // Note: this await can only resolve after we are dropped.
159 askpass_kill_master_rx.await.ok();
160 AskPassResult::CancelledByUser
161 }
162
163 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
164 AskPassResult::Timedout
165 }
166 }
167 }
168
169 /// This will return the password that was last set by the askpass script.
170 #[cfg(target_os = "windows")]
171 pub fn get_password(&self) -> Option<EncryptedPassword> {
172 self.secret.get().cloned()
173 }
174
175 pub fn script_path(&self) -> impl AsRef<OsStr> {
176 self.askpass_task.script_path()
177 }
178}
179
/// Serves password requests from the generated askpass script over a local socket.
///
/// Created with [`PasswordProxy::new`].
pub struct PasswordProxy {
    /// Background task accepting socket connections; it also owns the temp directory that
    /// holds the socket and script (presumably cancelled when dropped — gpui `Task` semantics).
    _task: Task<()>,
    /// Path of the generated `sh` askpass script.
    #[cfg(not(target_os = "windows"))]
    askpass_script_path: std::path::PathBuf,
    /// `powershell.exe ... -File <script>` command line that runs the generated script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
}
187
188impl PasswordProxy {
189 pub async fn new(
190 mut get_password: impl FnMut(String) -> Task<ControlFlow<(), Result<EncryptedPassword>>>
191 + 'static
192 + Send
193 + Sync,
194 executor: BackgroundExecutor,
195 ) -> Result<Self> {
196 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
197 let askpass_socket = temp_dir.path().join("askpass.sock");
198 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
199 let current_exec =
200 std::env::current_exe().context("Failed to determine current zed executable path.")?;
201
202 let askpass_program = ASKPASS_PROGRAM
203 .get_or_init(|| current_exec)
204 .try_shell_safe()
205 .context("Failed to shell-escape Askpass program path.")?
206 .to_string();
207 // Create an askpass script that communicates back to this process.
208 let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
209 let _task = executor.spawn(async move {
210 maybe!(async move {
211 let listener =
212 UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
213
214 while let Ok((mut stream, _)) = listener.accept().await {
215 let mut buffer = Vec::new();
216 let mut reader = BufReader::new(&mut stream);
217 if reader.read_until(b'\0', &mut buffer).await.is_err() {
218 buffer.clear();
219 }
220 let prompt = String::from_utf8_lossy(&buffer).into_owned();
221 let password = get_password(prompt).await;
222 match password {
223 ControlFlow::Continue(password) => {
224 if let Ok(password) = password
225 && let Ok(decrypted) =
226 password.decrypt(IKnowWhatIAmDoingAndIHaveReadTheDocs)
227 {
228 stream.write_all(decrypted.as_bytes()).await.log_err();
229 }
230 }
231 ControlFlow::Break(()) => {
232 // note: we expect the caller to drop this task when it's done.
233 // We need to keep the stream open until the caller is done to avoid
234 // spurious errors from ssh.
235 std::future::pending::<()>().await;
236 drop(stream);
237 }
238 }
239 }
240 drop(temp_dir);
241 Result::<_, anyhow::Error>::Ok(())
242 })
243 .await
244 .log_err();
245 });
246
247 fs::write(&askpass_script_path, askpass_script)
248 .await
249 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
250 make_file_executable(&askpass_script_path).await?;
251 #[cfg(target_os = "windows")]
252 let askpass_helper = format!(
253 "powershell.exe -ExecutionPolicy Bypass -File {}",
254 askpass_script_path.display()
255 );
256
257 Ok(Self {
258 _task,
259 #[cfg(not(target_os = "windows"))]
260 askpass_script_path,
261 #[cfg(target_os = "windows")]
262 askpass_helper,
263 })
264 }
265
266 pub fn script_path(&self) -> impl AsRef<OsStr> {
267 #[cfg(not(target_os = "windows"))]
268 {
269 &self.askpass_script_path
270 }
271 #[cfg(target_os = "windows")]
272 {
273 &self.askpass_helper
274 }
275 }
276}
277/// The main function for when Zed is running in netcat mode for use in askpass.
278/// Called from both the remote server binary and the zed binary in their respective main functions.
279pub fn main(socket: &str) {
280 use net::UnixStream;
281 use std::io::{self, Read, Write};
282 use std::process::exit;
283
284 let mut stream = match UnixStream::connect(socket) {
285 Ok(stream) => stream,
286 Err(err) => {
287 eprintln!("Error connecting to socket {}: {}", socket, err);
288 exit(1);
289 }
290 };
291
292 let mut buffer = Vec::new();
293 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
294 eprintln!("Error reading from stdin: {}", err);
295 exit(1);
296 }
297
298 #[cfg(target_os = "windows")]
299 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
300 buffer.pop();
301 }
302 if buffer.last() != Some(&b'\0') {
303 buffer.push(b'\0');
304 }
305
306 if let Err(err) = stream.write_all(&buffer) {
307 eprintln!("Error writing to socket: {}", err);
308 exit(1);
309 }
310
311 let mut response = Vec::new();
312 if let Err(err) = stream.read_to_end(&mut response) {
313 eprintln!("Error reading from socket: {}", err);
314 exit(1);
315 }
316
317 if let Err(err) = io::stdout().write_all(&response) {
318 eprintln!("Error writing to stdout: {}", err);
319 exit(1);
320 }
321}
322
323pub fn set_askpass_program(path: std::path::PathBuf) {
324 if ASKPASS_PROGRAM.set(path).is_err() {
325 debug_panic!("askpass program has already been set");
326 }
327}
328
/// Builds the POSIX `sh` askpass script.
///
/// The script prints each of its arguments followed by a NUL byte and pipes them into
/// `<askpass_program> --askpass=<askpass_socket>`, discarding that program's stderr.
/// `askpass_program` must already be shell-escaped by the caller. The socket path is
/// single-quoted here, so temp directories that contain spaces or shell metacharacters
/// still produce one intact `--askpass=` argument.
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    // POSIX single-quoting: everything inside '...' is literal; an embedded `'` becomes
    // `'\''` (close the quote, emit an escaped quote, reopen the quote).
    let askpass_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', r"'\''")
    );
    format!(
        "{shebang}\n{print_args} | {askpass_program} --askpass={askpass_socket} 2> /dev/null \n",
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
339
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes them into
/// `<askpass_program> --askpass=<askpass_socket>`, discarding that program's stderr.
/// Both the program path and the `--askpass=` argument are emitted as single-quoted
/// PowerShell literals. That keeps paths with spaces as one argument and prevents `$` or
/// backtick expansion.
#[cfg(target_os = "windows")]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    // PowerShell single-quoted strings are literal; an embedded `'` is escaped by doubling it.
    fn ps_quote(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    let askpass_program = ps_quote(askpass_program);
    let askpass_arg = ps_quote(&format!("--askpass={}", askpass_socket.display()));
    format!(
        r#"
    $ErrorActionPreference = 'Stop';
    ($args -join [char]0) | & {askpass_program} {askpass_arg} 2> $null
    "#,
    )
}