1use std::{ffi::OsStr, time::Duration};
2
3use anyhow::{Context as _, Result};
4use futures::channel::{mpsc, oneshot};
5use futures::{
6 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
7 select_biased,
8};
9use gpui::{AsyncApp, BackgroundExecutor, Task};
10use smol::fs;
11use util::ResultExt as _;
12
/// Outcome of [`AskPassSession::run`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// The askpass script connected, and afterwards the session was told to stop,
    /// e.g. because the delegate could not supply a password.
    CancelledByUser,
    /// The askpass script did not connect within the connection timeout.
    Timedout,
}
18
/// Forwards askpass password prompts to a caller-supplied callback.
///
/// Each request is a prompt paired with a one-shot channel on which the
/// callback is expected to send back the password (see `ask_password`).
pub struct AskPassDelegate {
    /// Sends `(prompt, reply sender)` pairs to the task spawned in `new`.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
    /// Task that receives requests sent on `tx` and invokes the prompt callback;
    /// held so that it lives exactly as long as the delegate.
    _task: Task<()>,
}
23
24impl AskPassDelegate {
25 pub fn new(
26 cx: &mut AsyncApp,
27 password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
28 ) -> Self {
29 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
30 let task = cx.spawn(async move |cx: &mut AsyncApp| {
31 while let Some((prompt, channel)) = rx.next().await {
32 password_prompt(prompt, channel, cx);
33 }
34 });
35 Self { tx, _task: task }
36 }
37
38 pub async fn ask_password(&mut self, prompt: String) -> Result<String> {
39 let (tx, rx) = oneshot::channel();
40 self.tx.send((prompt, tx)).await?;
41 Ok(rx.await?)
42 }
43}
44
/// A live askpass endpoint: a generated helper script plus the background task
/// that answers the script's password requests.
///
/// Must be kept alive until the master process that uses the script has exited.
pub struct AskPassSession {
    /// Path of the generated askpass shell script.
    #[cfg(not(target_os = "windows"))]
    script_path: std::path::PathBuf,
    /// Command line that runs the generated PowerShell askpass script.
    #[cfg(target_os = "windows")]
    askpass_helper: String,
    /// The most recent password handed out to the askpass script; written by the
    /// background task and read by `get_password`.
    #[cfg(target_os = "windows")]
    secret: std::sync::Arc<parking_lot::Mutex<String>>,
    /// Background task accepting askpass connections; it also owns the temporary
    /// directory that holds the socket and the script.
    _askpass_task: Task<()>,
    /// Resolves when the first askpass connection is accepted; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    /// Resolves when the delegate fails to supply a password (or the background
    /// task drops its sender); taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
56
/// File name of the generated askpass helper: a POSIX `sh` script on
/// non-Windows systems.
#[cfg(not(target_os = "windows"))]
const ASKPASS_SCRIPT_NAME: &str = "askpass.sh";
/// File name of the generated askpass helper: a PowerShell script on Windows.
#[cfg(target_os = "windows")]
const ASKPASS_SCRIPT_NAME: &str = "askpass.ps1";
61
62impl AskPassSession {
63 /// This will create a new AskPassSession.
64 /// You must retain this session until the master process exits.
65 #[must_use]
66 pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
67 use net::async_net::UnixListener;
68 use util::fs::make_file_executable;
69
70 #[cfg(target_os = "windows")]
71 let secret = std::sync::Arc::new(parking_lot::Mutex::new(String::new()));
72 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
73 let askpass_socket = temp_dir.path().join("askpass.sock");
74 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
75 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
76 let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
77 #[cfg(not(target_os = "windows"))]
78 let zed_path = util::get_shell_safe_zed_path()?;
79 #[cfg(target_os = "windows")]
80 let zed_path = std::env::current_exe()
81 .context("finding current executable path for use in askpass")?;
82
83 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
84 let mut kill_tx = Some(askpass_kill_master_tx);
85
86 #[cfg(target_os = "windows")]
87 let askpass_secret = secret.clone();
88 let askpass_task = executor.spawn(async move {
89 let mut askpass_opened_tx = Some(askpass_opened_tx);
90
91 while let Ok((mut stream, _)) = listener.accept().await {
92 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
93 askpass_opened_tx.send(()).ok();
94 }
95 let mut buffer = Vec::new();
96 let mut reader = BufReader::new(&mut stream);
97 if reader.read_until(b'\0', &mut buffer).await.is_err() {
98 buffer.clear();
99 }
100 let prompt = String::from_utf8_lossy(&buffer);
101 if let Some(password) = delegate
102 .ask_password(prompt.to_string())
103 .await
104 .context("getting askpass password")
105 .log_err()
106 {
107 stream.write_all(password.as_bytes()).await.log_err();
108 #[cfg(target_os = "windows")]
109 {
110 *askpass_secret.lock() = password;
111 }
112 } else {
113 if let Some(kill_tx) = kill_tx.take() {
114 kill_tx.send(()).log_err();
115 }
116 // note: we expect the caller to drop this task when it's done.
117 // We need to keep the stream open until the caller is done to avoid
118 // spurious errors from ssh.
119 std::future::pending::<()>().await;
120 drop(stream);
121 }
122 }
123 drop(temp_dir)
124 });
125
126 // Create an askpass script that communicates back to this process.
127 let askpass_script = generate_askpass_script(&zed_path, &askpass_socket);
128 fs::write(&askpass_script_path, askpass_script)
129 .await
130 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
131 make_file_executable(&askpass_script_path).await?;
132 #[cfg(target_os = "windows")]
133 let askpass_helper = format!(
134 "powershell.exe -ExecutionPolicy Bypass -File {}",
135 askpass_script_path.display()
136 );
137
138 Ok(Self {
139 #[cfg(not(target_os = "windows"))]
140 script_path: askpass_script_path,
141
142 #[cfg(target_os = "windows")]
143 secret,
144 #[cfg(target_os = "windows")]
145 askpass_helper,
146
147 _askpass_task: askpass_task,
148 askpass_kill_master_rx: Some(askpass_kill_master_rx),
149 askpass_opened_rx: Some(askpass_opened_rx),
150 })
151 }
152
153 #[cfg(not(target_os = "windows"))]
154 pub fn script_path(&self) -> impl AsRef<OsStr> {
155 &self.script_path
156 }
157
158 #[cfg(target_os = "windows")]
159 pub fn script_path(&self) -> impl AsRef<OsStr> {
160 &self.askpass_helper
161 }
162
163 // This will run the askpass task forever, resolving as many authentication requests as needed.
164 // The caller is responsible for examining the result of their own commands and cancelling this
165 // future when this is no longer needed. Note that this can only be called once, but due to the
166 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
167 pub async fn run(&mut self) -> AskPassResult {
168 // This is the default timeout setting used by VSCode.
169 let connection_timeout = Duration::from_secs(17);
170 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
171 let askpass_kill_master_rx = self
172 .askpass_kill_master_rx
173 .take()
174 .expect("Only call run once");
175
176 select_biased! {
177 _ = askpass_opened_rx.fuse() => {
178 // Note: this await can only resolve after we are dropped.
179 askpass_kill_master_rx.await.ok();
180 AskPassResult::CancelledByUser
181 }
182
183 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
184 AskPassResult::Timedout
185 }
186 }
187 }
188
189 /// This will return the password that was last set by the askpass script.
190 #[cfg(target_os = "windows")]
191 pub fn get_password(&self) -> String {
192 self.secret.lock().clone()
193 }
194}
195
196/// The main function for when Zed is running in netcat mode for use in askpass.
197/// Called from both the remote server binary and the zed binary in their respective main functions.
198pub fn main(socket: &str) {
199 use net::UnixStream;
200 use std::io::{self, Read, Write};
201 use std::process::exit;
202
203 let mut stream = match UnixStream::connect(socket) {
204 Ok(stream) => stream,
205 Err(err) => {
206 eprintln!("Error connecting to socket {}: {}", socket, err);
207 exit(1);
208 }
209 };
210
211 let mut buffer = Vec::new();
212 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
213 eprintln!("Error reading from stdin: {}", err);
214 exit(1);
215 }
216
217 #[cfg(target_os = "windows")]
218 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
219 buffer.pop();
220 }
221 if buffer.last() != Some(&b'\0') {
222 buffer.push(b'\0');
223 }
224
225 if let Err(err) = stream.write_all(&buffer) {
226 eprintln!("Error writing to socket: {}", err);
227 exit(1);
228 }
229
230 let mut response = Vec::new();
231 if let Err(err) = stream.read_to_end(&mut response) {
232 eprintln!("Error reading from socket: {}", err);
233 exit(1);
234 }
235
236 if let Err(err) = io::stdout().write_all(&response) {
237 eprintln!("Error writing to stdout: {}", err);
238 exit(1);
239 }
240}
241
/// Builds the POSIX `sh` askpass script.
///
/// The script writes each of its arguments NUL-terminated into
/// `<zed_path> --askpass=<socket>` (Zed's netcat mode), discarding its stderr;
/// whatever that command prints becomes the script's output.
///
/// `zed_path` is inserted verbatim and must already be shell-safe. The socket
/// path is single-quoted, because the temporary directory it lives in may
/// contain spaces or other shell metacharacters.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_path: &str, askpass_socket: &std::path::Path) -> String {
    // A POSIX single-quoted string cannot contain `'`, so each one is written
    // as: close quote, escaped quote, reopen quote.
    let quoted_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', "'\\''")
    );
    format!(
        "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
        zed_exe = zed_path,
        askpass_socket = quoted_socket,
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
253
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes them into
/// `<zed_path> --askpass=<socket>` (Zed's netcat mode), discarding its stderr.
///
/// Both the executable path and the `--askpass` argument are emitted as
/// PowerShell single-quoted literals. That keeps paths with spaces (e.g. a user
/// profile such as `C:\Users\John Doe\...`) as a single argument and prevents
/// `$` expansion.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_path: &std::path::Path, askpass_socket: &std::path::Path) -> String {
    // In a PowerShell single-quoted string, a literal `'` is written as `''`.
    fn quote(text: &str) -> String {
        format!("'{}'", text.replace('\'', "''"))
    }
    format!(
        r#"
        $ErrorActionPreference = 'Stop';
        ($args -join [char]0) | & {zed_exe} {askpass_arg} 2> $null
        "#,
        zed_exe = quote(&zed_path.display().to_string()),
        askpass_arg = quote(&format!("--askpass={}", askpass_socket.display())),
    )
}