1mod encrypted_password;
2
3pub use encrypted_password::{EncryptedPassword, ProcessExt};
4use util::paths::PathExt;
5
6use std::sync::OnceLock;
7use std::{ffi::OsStr, time::Duration};
8
9use anyhow::{Context as _, Result};
10use futures::channel::{mpsc, oneshot};
11use futures::{
12 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
13 select_biased,
14};
15use gpui::{AsyncApp, BackgroundExecutor, Task};
16use smol::fs;
17use util::{ResultExt as _, debug_panic};
18
19use crate::encrypted_password::decrypt;
20
/// Path to the program that the generated askpass script invokes.
///
/// `set_askpass_program` stores an explicit path here; if nothing has been stored by the time
/// an `AskPassSession` is created, the session initializes it to the current executable.
///
/// On Unix and remote servers, this defaults to the current executable.
/// On Windows, this is set to the CLI variant of zed.
static ASKPASS_PROGRAM: OnceLock<std::path::PathBuf> = OnceLock::new();
26
/// Outcome reported by `AskPassSession::run`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// An askpass helper connected, and the session's kill signal subsequently fired
    /// (the background task sends it when a password request fails, e.g. because the
    /// reply channel was dropped without a password being supplied).
    CancelledByUser,
    /// No askpass helper connected before the connection timeout elapsed.
    Timedout,
}
32
33pub struct AskPassDelegate {
34 tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
35 _task: Task<()>,
36}
37
38impl AskPassDelegate {
39 pub fn new(
40 cx: &mut AsyncApp,
41 password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
42 + Send
43 + Sync
44 + 'static,
45 ) -> Self {
46 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
47 let task = cx.spawn(async move |cx: &mut AsyncApp| {
48 while let Some((prompt, channel)) = rx.next().await {
49 password_prompt(prompt, channel, cx);
50 }
51 });
52 Self { tx, _task: task }
53 }
54
55 pub async fn ask_password(&mut self, prompt: String) -> Result<EncryptedPassword> {
56 let (tx, rx) = oneshot::channel();
57 self.tx.send((prompt, tx)).await?;
58 Ok(rx.await?)
59 }
60}
61
62pub struct AskPassSession {
63 #[cfg(not(target_os = "windows"))]
64 script_path: std::path::PathBuf,
65 #[cfg(target_os = "windows")]
66 askpass_helper: String,
67 #[cfg(target_os = "windows")]
68 secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
69 _askpass_task: Task<()>,
70 askpass_opened_rx: Option<oneshot::Receiver<()>>,
71 askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
72}
73
74#[cfg(not(target_os = "windows"))]
75const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
76#[cfg(target_os = "windows")]
77const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
78
79impl AskPassSession {
80 /// This will create a new AskPassSession.
81 /// You must retain this session until the master process exits.
82 #[must_use]
83 pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
84 use net::async_net::UnixListener;
85 use util::fs::make_file_executable;
86
87 #[cfg(target_os = "windows")]
88 let secret = std::sync::Arc::new(OnceLock::new());
89 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
90 let askpass_socket = temp_dir.path().join("askpass.sock");
91 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
92 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
93 let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
94
95 let current_exec =
96 std::env::current_exe().context("Failed to determine current zed executable path.")?;
97
98 let askpass_program = ASKPASS_PROGRAM
99 .get_or_init(|| current_exec)
100 .try_shell_safe()
101 .context("Failed to shell-escape Askpass program path.")?
102 .to_string();
103
104 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
105 let mut kill_tx = Some(askpass_kill_master_tx);
106
107 #[cfg(target_os = "windows")]
108 let askpass_secret = secret.clone();
109 let askpass_task = executor.spawn(async move {
110 let mut askpass_opened_tx = Some(askpass_opened_tx);
111
112 while let Ok((mut stream, _)) = listener.accept().await {
113 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
114 askpass_opened_tx.send(()).ok();
115 }
116 let mut buffer = Vec::new();
117 let mut reader = BufReader::new(&mut stream);
118 if reader.read_until(b'\0', &mut buffer).await.is_err() {
119 buffer.clear();
120 }
121 let prompt = String::from_utf8_lossy(&buffer);
122 if let Some(password) = delegate
123 .ask_password(prompt.to_string())
124 .await
125 .context("getting askpass password")
126 .log_err()
127 {
128 #[cfg(target_os = "windows")]
129 {
130 askpass_secret.get_or_init(|| password.clone());
131 }
132 if let Ok(decrypted) = decrypt(password) {
133 stream.write_all(decrypted.as_bytes()).await.log_err();
134 }
135 } else {
136 if let Some(kill_tx) = kill_tx.take() {
137 kill_tx.send(()).log_err();
138 }
139 // note: we expect the caller to drop this task when it's done.
140 // We need to keep the stream open until the caller is done to avoid
141 // spurious errors from ssh.
142 std::future::pending::<()>().await;
143 drop(stream);
144 }
145 }
146 drop(temp_dir)
147 });
148
149 // Create an askpass script that communicates back to this process.
150 let askpass_script = generate_askpass_script(&askpass_program, &askpass_socket);
151 fs::write(&askpass_script_path, askpass_script)
152 .await
153 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
154 make_file_executable(&askpass_script_path).await?;
155 #[cfg(target_os = "windows")]
156 let askpass_helper = format!(
157 "powershell.exe -ExecutionPolicy Bypass -File {}",
158 askpass_script_path.display()
159 );
160
161 Ok(Self {
162 #[cfg(not(target_os = "windows"))]
163 script_path: askpass_script_path,
164
165 #[cfg(target_os = "windows")]
166 secret,
167 #[cfg(target_os = "windows")]
168 askpass_helper,
169
170 _askpass_task: askpass_task,
171 askpass_kill_master_rx: Some(askpass_kill_master_rx),
172 askpass_opened_rx: Some(askpass_opened_rx),
173 })
174 }
175
176 #[cfg(not(target_os = "windows"))]
177 pub fn script_path(&self) -> impl AsRef<OsStr> {
178 &self.script_path
179 }
180
181 #[cfg(target_os = "windows")]
182 pub fn script_path(&self) -> impl AsRef<OsStr> {
183 &self.askpass_helper
184 }
185
186 // This will run the askpass task forever, resolving as many authentication requests as needed.
187 // The caller is responsible for examining the result of their own commands and cancelling this
188 // future when this is no longer needed. Note that this can only be called once, but due to the
189 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
190 pub async fn run(&mut self) -> AskPassResult {
191 // This is the default timeout setting used by VSCode.
192 let connection_timeout = Duration::from_secs(17);
193 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
194 let askpass_kill_master_rx = self
195 .askpass_kill_master_rx
196 .take()
197 .expect("Only call run once");
198
199 select_biased! {
200 _ = askpass_opened_rx.fuse() => {
201 // Note: this await can only resolve after we are dropped.
202 askpass_kill_master_rx.await.ok();
203 AskPassResult::CancelledByUser
204 }
205
206 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
207 AskPassResult::Timedout
208 }
209 }
210 }
211
212 /// This will return the password that was last set by the askpass script.
213 #[cfg(target_os = "windows")]
214 pub fn get_password(&self) -> Option<EncryptedPassword> {
215 self.secret.get().cloned()
216 }
217}
218
219/// The main function for when Zed is running in netcat mode for use in askpass.
220/// Called from both the remote server binary and the zed binary in their respective main functions.
221pub fn main(socket: &str) {
222 use net::UnixStream;
223 use std::io::{self, Read, Write};
224 use std::process::exit;
225
226 let mut stream = match UnixStream::connect(socket) {
227 Ok(stream) => stream,
228 Err(err) => {
229 eprintln!("Error connecting to socket {}: {}", socket, err);
230 exit(1);
231 }
232 };
233
234 let mut buffer = Vec::new();
235 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
236 eprintln!("Error reading from stdin: {}", err);
237 exit(1);
238 }
239
240 #[cfg(target_os = "windows")]
241 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
242 buffer.pop();
243 }
244 if buffer.last() != Some(&b'\0') {
245 buffer.push(b'\0');
246 }
247
248 if let Err(err) = stream.write_all(&buffer) {
249 eprintln!("Error writing to socket: {}", err);
250 exit(1);
251 }
252
253 let mut response = Vec::new();
254 if let Err(err) = stream.read_to_end(&mut response) {
255 eprintln!("Error reading from socket: {}", err);
256 exit(1);
257 }
258
259 if let Err(err) = io::stdout().write_all(&response) {
260 eprintln!("Error writing to stdout: {}", err);
261 exit(1);
262 }
263}
264
265pub fn set_askpass_program(path: std::path::PathBuf) {
266 if ASKPASS_PROGRAM.set(path).is_err() {
267 debug_panic!("askpass program has already been set");
268 }
269}
270
/// Builds the `/bin/sh` askpass helper script.
///
/// The script prints each of its arguments followed by a NUL byte and pipes that into
/// `askpass_program --askpass=<socket>`, discarding the program's stderr.
///
/// `askpass_program` is inserted verbatim and must already be safe to embed in a shell
/// command. `askpass_socket` is single-quoted here, so paths containing whitespace or
/// shell metacharacters are passed to the program as one unmodified argument.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    // Inside single quotes only `'` is special to the shell; emit it as `'\''`
    // (close the quote, add an escaped quote, reopen the quote).
    let askpass_socket = askpass_socket.display().to_string().replace('\'', r"'\''");
    let print_args = "printf '%s\\0' \"$@\"";
    let shebang = "#!/bin/sh";
    format!(
        "{shebang}\n{print_args} | {askpass_program} --askpass='{askpass_socket}' 2> /dev/null \n"
    )
}
281
/// Builds the PowerShell askpass helper script.
///
/// The script joins its arguments with NUL characters and pipes the result into
/// `askpass_program --askpass=<socket>`, discarding the program's stderr.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(askpass_program: &str, askpass_socket: &std::path::Path) -> String {
    let askpass_socket = askpass_socket.display();
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & "{askpass_program}" --askpass={askpass_socket} 2> $null
        "#
    )
}