1use std::path::{Path, PathBuf};
2use std::time::Duration;
3
4#[cfg(unix)]
5use anyhow::Context as _;
6use futures::channel::{mpsc, oneshot};
7#[cfg(unix)]
8use futures::{AsyncBufReadExt as _, io::BufReader};
9#[cfg(unix)]
10use futures::{AsyncWriteExt as _, FutureExt as _, select_biased};
11use futures::{SinkExt, StreamExt};
12use gpui::{AsyncApp, BackgroundExecutor, Task};
13#[cfg(unix)]
14use smol::fs;
15#[cfg(unix)]
16use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
17#[cfg(unix)]
18use util::ResultExt as _;
19
/// Outcome reported by `AskPassSession::run` once it stops waiting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// An askpass connection was opened and the session was subsequently told
    /// to stop, e.g. because the password request to the delegate failed.
    CancelledByUser,
    /// No askpass connection was observed before the wait period elapsed.
    Timedout,
}
25
26pub struct AskPassDelegate {
27 tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
28 _task: Task<()>,
29}
30
31impl AskPassDelegate {
32 pub fn new(
33 cx: &mut AsyncApp,
34 password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
35 ) -> Self {
36 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
37 let task = cx.spawn(async move |cx: &mut AsyncApp| {
38 while let Some((prompt, channel)) = rx.next().await {
39 password_prompt(prompt, channel, cx);
40 }
41 });
42 Self { tx, _task: task }
43 }
44
45 pub async fn ask_password(&mut self, prompt: String) -> anyhow::Result<String> {
46 let (tx, rx) = oneshot::channel();
47 self.tx.send((prompt, tx)).await?;
48 Ok(rx.await?)
49 }
50}
51
52#[cfg(unix)]
53pub struct AskPassSession {
54 script_path: PathBuf,
55 _askpass_task: Task<()>,
56 askpass_opened_rx: Option<oneshot::Receiver<()>>,
57 askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
58}
59
60#[cfg(unix)]
61impl AskPassSession {
62 /// This will create a new AskPassSession.
63 /// You must retain this session until the master process exits.
64 #[must_use]
65 pub async fn new(
66 executor: &BackgroundExecutor,
67 mut delegate: AskPassDelegate,
68 ) -> anyhow::Result<Self> {
69 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
70 let askpass_socket = temp_dir.path().join("askpass.sock");
71 let askpass_script_path = temp_dir.path().join("askpass.sh");
72 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
73 let listener =
74 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
75 let zed_path = std::env::current_exe()
76 .context("Failed to figure out current executable path for use in askpass")?;
77
78 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
79 let mut kill_tx = Some(askpass_kill_master_tx);
80
81 let askpass_task = executor.spawn(async move {
82 let mut askpass_opened_tx = Some(askpass_opened_tx);
83
84 while let Ok((mut stream, _)) = listener.accept().await {
85 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
86 askpass_opened_tx.send(()).ok();
87 }
88 let mut buffer = Vec::new();
89 let mut reader = BufReader::new(&mut stream);
90 if reader.read_until(b'\0', &mut buffer).await.is_err() {
91 buffer.clear();
92 }
93 let prompt = String::from_utf8_lossy(&buffer);
94 if let Some(password) = delegate
95 .ask_password(prompt.to_string())
96 .await
97 .context("failed to get askpass password")
98 .log_err()
99 {
100 stream.write_all(password.as_bytes()).await.log_err();
101 } else {
102 if let Some(kill_tx) = kill_tx.take() {
103 kill_tx.send(()).log_err();
104 }
105 // note: we expect the caller to drop this task when it's done.
106 // We need to keep the stream open until the caller is done to avoid
107 // spurious errors from ssh.
108 std::future::pending::<()>().await;
109 drop(stream);
110 }
111 }
112 drop(temp_dir)
113 });
114
115 // Create an askpass script that communicates back to this process.
116 let askpass_script = format!(
117 "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
118 zed_exe = zed_path.display(),
119 askpass_socket = askpass_socket.display(),
120 print_args = "printf '%s\\0' \"$@\"",
121 shebang = "#!/bin/sh",
122 );
123 fs::write(&askpass_script_path, askpass_script).await?;
124 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
125
126 Ok(Self {
127 script_path: askpass_script_path,
128 _askpass_task: askpass_task,
129 askpass_kill_master_rx: Some(askpass_kill_master_rx),
130 askpass_opened_rx: Some(askpass_opened_rx),
131 })
132 }
133
134 pub fn script_path(&self) -> &Path {
135 &self.script_path
136 }
137
138 // This will run the askpass task forever, resolving as many authentication requests as needed.
139 // The caller is responsible for examining the result of their own commands and cancelling this
140 // future when this is no longer needed. Note that this can only be called once, but due to the
141 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
142 pub async fn run(&mut self) -> AskPassResult {
143 let connection_timeout = Duration::from_secs(10);
144 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
145 let askpass_kill_master_rx = self
146 .askpass_kill_master_rx
147 .take()
148 .expect("Only call run once");
149
150 select_biased! {
151 _ = askpass_opened_rx.fuse() => {
152 // Note: this await can only resolve after we are dropped.
153 askpass_kill_master_rx.await.ok();
154 return AskPassResult::CancelledByUser
155 }
156
157 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
158 return AskPassResult::Timedout
159 }
160 }
161 }
162}
163
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
///
/// Connects to the Unix socket at `socket`, forwards all of stdin to it
/// (appending a NUL terminator if one is missing), then copies everything the
/// socket sends back to stdout. On any I/O failure it prints a message to
/// stderr and exits the process with status 1.
#[cfg(unix)]
pub fn main(socket: &str) {
    use std::io::{self, Read, Write};
    use std::os::unix::net::UnixStream;
    use std::process::exit;

    let mut stream = match UnixStream::connect(socket) {
        Ok(stream) => stream,
        Err(err) => {
            eprintln!("Error connecting to socket {}: {}", socket, err);
            exit(1);
        }
    };

    let mut buffer = Vec::new();
    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
        eprintln!("Error reading from stdin: {}", err);
        exit(1);
    }

    // Ensure the request always ends with a NUL terminator.
    if buffer.last() != Some(&b'\0') {
        buffer.push(b'\0');
    }

    if let Err(err) = stream.write_all(&buffer) {
        eprintln!("Error writing to socket: {}", err);
        exit(1);
    }

    let mut response = Vec::new();
    if let Err(err) = stream.read_to_end(&mut response) {
        eprintln!("Error reading from socket: {}", err);
        exit(1);
    }

    // Stdout is buffered; flush explicitly so a failed write is reported here
    // instead of being silently discarded when the process exits.
    let mut stdout = io::stdout().lock();
    if let Err(err) = stdout
        .write_all(&response)
        .and_then(|()| stdout.flush())
    {
        eprintln!("Error writing to stdout: {}", err);
        exit(1);
    }
}
/// Non-Unix counterpart of the askpass netcat entry point: it ignores its
/// argument and returns immediately.
#[cfg(not(unix))]
pub fn main(_socket: &str) {
    // Intentionally empty.
}
208
/// Placeholder askpass session for non-Unix targets; it creates no script.
#[cfg(not(unix))]
pub struct AskPassSession {
    /// Always the empty path; nothing is written to disk.
    path: PathBuf,
}

#[cfg(not(unix))]
impl AskPassSession {
    /// Returns a placeholder session, ignoring both arguments.
    pub async fn new(_: &BackgroundExecutor, _: AskPassDelegate) -> anyhow::Result<Self> {
        let path = PathBuf::new();
        Ok(Self { path })
    }

    /// Path of the askpass script, which is empty for this placeholder.
    pub fn script_path(&self) -> &Path {
        self.path.as_path()
    }

    /// Waits 20 seconds and then reports [`AskPassResult::Timedout`].
    pub async fn run(&mut self) -> AskPassResult {
        let wait = Duration::from_secs(20);
        let timer = futures::FutureExt::fuse(smol::Timer::after(wait));
        timer.await;
        AskPassResult::Timedout
    }
}