1mod encrypted_password;
2
3pub use encrypted_password::{EncryptedPassword, ProcessExt};
4
5#[cfg(target_os = "windows")]
6use std::sync::OnceLock;
7use std::{ffi::OsStr, time::Duration};
8
9use anyhow::{Context as _, Result};
10use futures::channel::{mpsc, oneshot};
11use futures::{
12 AsyncBufReadExt as _, AsyncWriteExt as _, FutureExt as _, SinkExt, StreamExt, io::BufReader,
13 select_biased,
14};
15use gpui::{AsyncApp, BackgroundExecutor, Task};
16use smol::fs;
17use util::ResultExt as _;
18
19use crate::encrypted_password::decrypt;
20
/// Outcome reported by [`AskPassSession::run`].
///
/// `Debug` is derived so the value can be logged and compared with
/// `assert_eq!`/`assert_ne!`.
#[derive(Debug, PartialEq, Eq)]
pub enum AskPassResult {
    /// The askpass helper connected, and afterwards the kill-master channel
    /// resolved (it is signalled when the delegate yields no password).
    CancelledByUser,
    /// No askpass connection was observed before the connection timeout elapsed.
    Timedout,
}
26
27pub struct AskPassDelegate {
28 tx: mpsc::UnboundedSender<(String, oneshot::Sender<EncryptedPassword>)>,
29 _task: Task<()>,
30}
31
32impl AskPassDelegate {
33 pub fn new(
34 cx: &mut AsyncApp,
35 password_prompt: impl Fn(String, oneshot::Sender<EncryptedPassword>, &mut AsyncApp)
36 + Send
37 + Sync
38 + 'static,
39 ) -> Self {
40 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<_>)>();
41 let task = cx.spawn(async move |cx: &mut AsyncApp| {
42 while let Some((prompt, channel)) = rx.next().await {
43 password_prompt(prompt, channel, cx);
44 }
45 });
46 Self { tx, _task: task }
47 }
48
49 pub async fn ask_password(&mut self, prompt: String) -> Result<EncryptedPassword> {
50 let (tx, rx) = oneshot::channel();
51 self.tx.send((prompt, tx)).await?;
52 Ok(rx.await?)
53 }
54}
55
56pub struct AskPassSession {
57 #[cfg(not(target_os = "windows"))]
58 script_path: std::path::PathBuf,
59 #[cfg(target_os = "windows")]
60 askpass_helper: String,
61 #[cfg(target_os = "windows")]
62 secret: std::sync::Arc<OnceLock<EncryptedPassword>>,
63 _askpass_task: Task<()>,
64 askpass_opened_rx: Option<oneshot::Receiver<()>>,
65 askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
66}
67
/// File name of the generated askpass script inside the session's temporary
/// directory: a PowerShell script on Windows, a shell script everywhere else.
const ASKPASS_SCRIPT_NAME: &str = if cfg!(target_os = "windows") {
    "askpass.ps1"
} else {
    "askpass.sh"
};
72
73impl AskPassSession {
74 /// This will create a new AskPassSession.
75 /// You must retain this session until the master process exits.
76 #[must_use]
77 pub async fn new(executor: &BackgroundExecutor, mut delegate: AskPassDelegate) -> Result<Self> {
78 use net::async_net::UnixListener;
79 use util::fs::make_file_executable;
80
81 #[cfg(target_os = "windows")]
82 let secret = std::sync::Arc::new(OnceLock::new());
83 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
84 let askpass_socket = temp_dir.path().join("askpass.sock");
85 let askpass_script_path = temp_dir.path().join(ASKPASS_SCRIPT_NAME);
86 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
87 let listener = UnixListener::bind(&askpass_socket).context("creating askpass socket")?;
88 let zed_cli_path =
89 util::get_shell_safe_zed_cli_path().context("getting zed-cli path for askpass")?;
90
91 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
92 let mut kill_tx = Some(askpass_kill_master_tx);
93
94 #[cfg(target_os = "windows")]
95 let askpass_secret = secret.clone();
96 let askpass_task = executor.spawn(async move {
97 let mut askpass_opened_tx = Some(askpass_opened_tx);
98
99 while let Ok((mut stream, _)) = listener.accept().await {
100 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
101 askpass_opened_tx.send(()).ok();
102 }
103 let mut buffer = Vec::new();
104 let mut reader = BufReader::new(&mut stream);
105 if reader.read_until(b'\0', &mut buffer).await.is_err() {
106 buffer.clear();
107 }
108 let prompt = String::from_utf8_lossy(&buffer);
109 if let Some(password) = delegate
110 .ask_password(prompt.to_string())
111 .await
112 .context("getting askpass password")
113 .log_err()
114 {
115 #[cfg(target_os = "windows")]
116 {
117 askpass_secret.get_or_init(|| password.clone());
118 }
119 if let Ok(decrypted) = decrypt(password) {
120 stream.write_all(decrypted.as_bytes()).await.log_err();
121 }
122 } else {
123 if let Some(kill_tx) = kill_tx.take() {
124 kill_tx.send(()).log_err();
125 }
126 // note: we expect the caller to drop this task when it's done.
127 // We need to keep the stream open until the caller is done to avoid
128 // spurious errors from ssh.
129 std::future::pending::<()>().await;
130 drop(stream);
131 }
132 }
133 drop(temp_dir)
134 });
135
136 // Create an askpass script that communicates back to this process.
137 let askpass_script = generate_askpass_script(&zed_cli_path, &askpass_socket);
138 fs::write(&askpass_script_path, askpass_script)
139 .await
140 .with_context(|| format!("creating askpass script at {askpass_script_path:?}"))?;
141 make_file_executable(&askpass_script_path).await?;
142 #[cfg(target_os = "windows")]
143 let askpass_helper = format!(
144 "powershell.exe -ExecutionPolicy Bypass -File {}",
145 askpass_script_path.display()
146 );
147
148 Ok(Self {
149 #[cfg(not(target_os = "windows"))]
150 script_path: askpass_script_path,
151
152 #[cfg(target_os = "windows")]
153 secret,
154 #[cfg(target_os = "windows")]
155 askpass_helper,
156
157 _askpass_task: askpass_task,
158 askpass_kill_master_rx: Some(askpass_kill_master_rx),
159 askpass_opened_rx: Some(askpass_opened_rx),
160 })
161 }
162
163 #[cfg(not(target_os = "windows"))]
164 pub fn script_path(&self) -> impl AsRef<OsStr> {
165 &self.script_path
166 }
167
168 #[cfg(target_os = "windows")]
169 pub fn script_path(&self) -> impl AsRef<OsStr> {
170 &self.askpass_helper
171 }
172
173 // This will run the askpass task forever, resolving as many authentication requests as needed.
174 // The caller is responsible for examining the result of their own commands and cancelling this
175 // future when this is no longer needed. Note that this can only be called once, but due to the
176 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
177 pub async fn run(&mut self) -> AskPassResult {
178 // This is the default timeout setting used by VSCode.
179 let connection_timeout = Duration::from_secs(17);
180 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
181 let askpass_kill_master_rx = self
182 .askpass_kill_master_rx
183 .take()
184 .expect("Only call run once");
185
186 select_biased! {
187 _ = askpass_opened_rx.fuse() => {
188 // Note: this await can only resolve after we are dropped.
189 askpass_kill_master_rx.await.ok();
190 AskPassResult::CancelledByUser
191 }
192
193 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
194 AskPassResult::Timedout
195 }
196 }
197 }
198
199 /// This will return the password that was last set by the askpass script.
200 #[cfg(target_os = "windows")]
201 pub fn get_password(&self) -> Option<EncryptedPassword> {
202 self.secret.get().cloned()
203 }
204}
205
206/// The main function for when Zed is running in netcat mode for use in askpass.
207/// Called from both the remote server binary and the zed binary in their respective main functions.
208pub fn main(socket: &str) {
209 use net::UnixStream;
210 use std::io::{self, Read, Write};
211 use std::process::exit;
212
213 let mut stream = match UnixStream::connect(socket) {
214 Ok(stream) => stream,
215 Err(err) => {
216 eprintln!("Error connecting to socket {}: {}", socket, err);
217 exit(1);
218 }
219 };
220
221 let mut buffer = Vec::new();
222 if let Err(err) = io::stdin().read_to_end(&mut buffer) {
223 eprintln!("Error reading from stdin: {}", err);
224 exit(1);
225 }
226
227 #[cfg(target_os = "windows")]
228 while buffer.last().is_some_and(|&b| b == b'\n' || b == b'\r') {
229 buffer.pop();
230 }
231 if buffer.last() != Some(&b'\0') {
232 buffer.push(b'\0');
233 }
234
235 if let Err(err) = stream.write_all(&buffer) {
236 eprintln!("Error writing to socket: {}", err);
237 exit(1);
238 }
239
240 let mut response = Vec::new();
241 if let Err(err) = stream.read_to_end(&mut response) {
242 eprintln!("Error reading from socket: {}", err);
243 exit(1);
244 }
245
246 if let Err(err) = io::stdout().write_all(&response) {
247 eprintln!("Error writing to stdout: {}", err);
248 exit(1);
249 }
250}
251
/// Builds the `/bin/sh` askpass script.
///
/// The script prints each of its arguments followed by a NUL byte and pipes
/// that into `zed_cli_path --askpass=<socket>`, discarding stderr.
///
/// `zed_cli_path` is inserted verbatim, so it must already be safe to embed in
/// a shell command. `askpass_socket` is wrapped in single quotes (with embedded
/// single quotes escaped as `'\''`) so that a path containing spaces or shell
/// metacharacters is still passed as a single argument.
#[inline]
#[cfg(not(target_os = "windows"))]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    // Inside single quotes the shell treats every character literally, so the only
    // character needing care is the single quote itself: close, escape, reopen.
    let quoted_socket = format!(
        "'{}'",
        askpass_socket.display().to_string().replace('\'', r"'\''")
    );
    format!(
        "{shebang}\n{print_args} | {zed_cli} --askpass={askpass_socket} 2> /dev/null \n",
        zed_cli = zed_cli_path,
        askpass_socket = quoted_socket,
        print_args = "printf '%s\\0' \"$@\"",
        shebang = "#!/bin/sh",
    )
}
263
/// Builds the PowerShell askpass script.
///
/// The script joins its arguments with NUL characters and pipes the result into
/// `zed_cli_path --askpass=<socket>`, discarding the error stream.
#[inline]
#[cfg(target_os = "windows")]
fn generate_askpass_script(zed_cli_path: &str, askpass_socket: &std::path::Path) -> String {
    let socket = askpass_socket.display();
    format!(
        "\n $ErrorActionPreference = 'Stop';\n ($args -join [char]0) | & \"{zed_cli_path}\" --askpass={socket} 2> $null\n "
    )
}