1use std::path::{Path, PathBuf};
2use std::time::Duration;
3
4#[cfg(unix)]
5use anyhow::Context as _;
6use futures::channel::{mpsc, oneshot};
7#[cfg(unix)]
8use futures::{AsyncBufReadExt as _, io::BufReader};
9#[cfg(unix)]
10use futures::{AsyncWriteExt as _, FutureExt as _, select_biased};
11use futures::{SinkExt, StreamExt};
12use gpui::{AsyncApp, BackgroundExecutor, Task};
13#[cfg(unix)]
14use smol::fs;
15#[cfg(unix)]
16use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
17#[cfg(unix)]
18use util::ResultExt as _;
19#[cfg(unix)]
20use util::get_shell_safe_zed_path;
21
/// Why [`AskPassSession::run`] stopped waiting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AskPassResult {
    /// An askpass request could not be answered with a password (presumably the user
    /// dismissed the prompt), or the askpass task stopped.
    CancelledByUser,
    /// No askpass request arrived before the session's timeout elapsed.
    Timedout,
}
27
/// Forwards askpass prompts to a caller-supplied callback that runs on the app's async context.
pub struct AskPassDelegate {
    // Queue of (prompt text, reply channel) pairs; drained by `_task`.
    tx: mpsc::UnboundedSender<(String, oneshot::Sender<String>)>,
    // Loop spawned in `new` that reads the queue and invokes the prompt callback.
    _task: Task<()>,
}
32
33impl AskPassDelegate {
34 pub fn new(
35 cx: &mut AsyncApp,
36 password_prompt: impl Fn(String, oneshot::Sender<String>, &mut AsyncApp) + Send + Sync + 'static,
37 ) -> Self {
38 let (tx, mut rx) = mpsc::unbounded::<(String, oneshot::Sender<String>)>();
39 let task = cx.spawn(async move |cx: &mut AsyncApp| {
40 while let Some((prompt, channel)) = rx.next().await {
41 password_prompt(prompt, channel, cx);
42 }
43 });
44 Self { tx, _task: task }
45 }
46
47 pub async fn ask_password(&mut self, prompt: String) -> anyhow::Result<String> {
48 let (tx, rx) = oneshot::channel();
49 self.tx.send((prompt, tx)).await?;
50 Ok(rx.await?)
51 }
52}
53
/// A live askpass session: a generated `askpass.sh` script plus the background task that
/// answers the requests that script forwards over a Unix socket.
#[cfg(unix)]
pub struct AskPassSession {
    // Path of the generated `askpass.sh`, inside the session's temporary directory.
    script_path: PathBuf,
    // Task serving the askpass socket. It also owns the `TempDir` holding the script and
    // socket, which is presumably removed once the task (and its captured state) is dropped.
    _askpass_task: Task<()>,
    // Signalled on the first connection to the askpass socket; taken by `run`.
    askpass_opened_rx: Option<oneshot::Receiver<()>>,
    // Signalled when the askpass task fails to obtain a password; taken by `run`.
    askpass_kill_master_rx: Option<oneshot::Receiver<()>>,
}
61
62#[cfg(unix)]
63impl AskPassSession {
64 /// This will create a new AskPassSession.
65 /// You must retain this session until the master process exits.
66 #[must_use]
67 pub async fn new(
68 executor: &BackgroundExecutor,
69 mut delegate: AskPassDelegate,
70 ) -> anyhow::Result<Self> {
71 let temp_dir = tempfile::Builder::new().prefix("zed-askpass").tempdir()?;
72 let askpass_socket = temp_dir.path().join("askpass.sock");
73 let askpass_script_path = temp_dir.path().join("askpass.sh");
74 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
75 let listener =
76 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
77 let zed_path = get_shell_safe_zed_path()?;
78
79 let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
80 let mut kill_tx = Some(askpass_kill_master_tx);
81
82 let askpass_task = executor.spawn(async move {
83 let mut askpass_opened_tx = Some(askpass_opened_tx);
84
85 while let Ok((mut stream, _)) = listener.accept().await {
86 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
87 askpass_opened_tx.send(()).ok();
88 }
89 let mut buffer = Vec::new();
90 let mut reader = BufReader::new(&mut stream);
91 if reader.read_until(b'\0', &mut buffer).await.is_err() {
92 buffer.clear();
93 }
94 let prompt = String::from_utf8_lossy(&buffer);
95 if let Some(password) = delegate
96 .ask_password(prompt.to_string())
97 .await
98 .context("failed to get askpass password")
99 .log_err()
100 {
101 stream.write_all(password.as_bytes()).await.log_err();
102 } else {
103 if let Some(kill_tx) = kill_tx.take() {
104 kill_tx.send(()).log_err();
105 }
106 // note: we expect the caller to drop this task when it's done.
107 // We need to keep the stream open until the caller is done to avoid
108 // spurious errors from ssh.
109 std::future::pending::<()>().await;
110 drop(stream);
111 }
112 }
113 drop(temp_dir)
114 });
115
116 // Create an askpass script that communicates back to this process.
117 let askpass_script = format!(
118 "{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
119 zed_exe = zed_path,
120 askpass_socket = askpass_socket.display(),
121 print_args = "printf '%s\\0' \"$@\"",
122 shebang = "#!/bin/sh",
123 );
124 fs::write(&askpass_script_path, askpass_script).await?;
125 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
126
127 Ok(Self {
128 script_path: askpass_script_path,
129 _askpass_task: askpass_task,
130 askpass_kill_master_rx: Some(askpass_kill_master_rx),
131 askpass_opened_rx: Some(askpass_opened_rx),
132 })
133 }
134
135 pub fn script_path(&self) -> &Path {
136 &self.script_path
137 }
138
139 // This will run the askpass task forever, resolving as many authentication requests as needed.
140 // The caller is responsible for examining the result of their own commands and cancelling this
141 // future when this is no longer needed. Note that this can only be called once, but due to the
142 // drop order this takes an &mut, so you can `drop()` it after you're done with the master process.
143 pub async fn run(&mut self) -> AskPassResult {
144 let connection_timeout = Duration::from_secs(10);
145 let askpass_opened_rx = self.askpass_opened_rx.take().expect("Only call run once");
146 let askpass_kill_master_rx = self
147 .askpass_kill_master_rx
148 .take()
149 .expect("Only call run once");
150
151 select_biased! {
152 _ = askpass_opened_rx.fuse() => {
153 // Note: this await can only resolve after we are dropped.
154 askpass_kill_master_rx.await.ok();
155 return AskPassResult::CancelledByUser
156 }
157
158 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
159 return AskPassResult::Timedout
160 }
161 }
162 }
163}
164
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
///
/// Reads the prompt from stdin and forwards it, NUL-terminated, to the askpass socket at
/// `socket`. The socket's reply (the password) is then copied to stdout. Any failure is
/// reported on stderr and terminates the process with exit code 1.
#[cfg(unix)]
pub fn main(socket: &str) {
    use std::io::{self, Read, Write};
    use std::os::unix::net::UnixStream;
    use std::process::exit;

    let mut stream = match UnixStream::connect(socket) {
        Ok(stream) => stream,
        Err(err) => {
            eprintln!("Error connecting to socket {}: {}", socket, err);
            exit(1);
        }
    };

    let mut buffer = Vec::new();
    if let Err(err) = io::stdin().read_to_end(&mut buffer) {
        eprintln!("Error reading from stdin: {}", err);
        exit(1);
    }

    // The server reads up to the first NUL, so make sure one is always present.
    if buffer.last() != Some(&b'\0') {
        buffer.push(b'\0');
    }

    if let Err(err) = stream.write_all(&buffer) {
        eprintln!("Error writing to socket: {}", err);
        exit(1);
    }

    let mut response = Vec::new();
    if let Err(err) = stream.read_to_end(&mut response) {
        eprintln!("Error reading from socket: {}", err);
        exit(1);
    }

    // The password has no trailing newline, so line-buffered stdout would hold it until the
    // runtime's exit-time flush, which ignores errors. Flush explicitly so a failed delivery
    // is reported and yields a non-zero exit status.
    let mut stdout = io::stdout().lock();
    if let Err(err) = stdout.write_all(&response).and_then(|()| stdout.flush()) {
        eprintln!("Error writing to stdout: {}", err);
        exit(1);
    }
}
/// Askpass netcat mode is not implemented on non-Unix platforms; this stub does nothing.
#[cfg(not(unix))]
pub fn main(_socket: &str) {}
209
/// Placeholder askpass session for non-Unix platforms, where no askpass script is generated.
#[cfg(not(unix))]
pub struct AskPassSession {
    // Set to an empty path by `new`; returned by `script_path`.
    path: PathBuf,
}
214
#[cfg(not(unix))]
impl AskPassSession {
    /// Creates a placeholder session with an empty script path; the delegate is discarded.
    pub async fn new(_: &BackgroundExecutor, _: AskPassDelegate) -> anyhow::Result<Self> {
        let path = PathBuf::new();
        Ok(Self { path })
    }

    /// Returns the (empty) askpass script path.
    pub fn script_path(&self) -> &Path {
        self.path.as_path()
    }

    /// Sleeps for 20 seconds, then reports [`AskPassResult::Timedout`]; no prompt is ever served.
    pub async fn run(&mut self) -> AskPassResult {
        let wait = Duration::from_secs(20);
        smol::Timer::after(wait).await;
        AskPassResult::Timedout
    }
}