1use crate::HeadlessProject;
2use anyhow::{anyhow, Context, Result};
3use fs::RealFs;
4use futures::channel::mpsc;
5use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
6use gpui::{AppContext, Context as _};
7use remote::ssh_session::ChannelClient;
8use remote::{
9 json_log::LogRecord,
10 protocol::{read_message, write_message},
11};
12use rpc::proto::Envelope;
13use smol::Async;
14use smol::{io::AsyncWriteExt, net::unix::UnixListener, stream::StreamExt as _};
15use std::{
16 env,
17 io::Write,
18 mem,
19 path::{Path, PathBuf},
20 sync::Arc,
21};
22
23pub fn init(log_file: Option<PathBuf>) -> Result<()> {
24 init_logging(log_file)?;
25 init_panic_hook();
26 Ok(())
27}
28
29fn init_logging(log_file: Option<PathBuf>) -> Result<()> {
30 if let Some(log_file) = log_file {
31 let target = Box::new(if log_file.exists() {
32 std::fs::OpenOptions::new()
33 .append(true)
34 .open(&log_file)
35 .context("Failed to open log file in append mode")?
36 } else {
37 std::fs::File::create(&log_file).context("Failed to create log file")?
38 });
39
40 env_logger::Builder::from_default_env()
41 .target(env_logger::Target::Pipe(target))
42 .init();
43 } else {
44 env_logger::builder()
45 .format(|buf, record| {
46 serde_json::to_writer(&mut *buf, &LogRecord::new(record))?;
47 buf.write_all(b"\n")?;
48 Ok(())
49 })
50 .init();
51 }
52 Ok(())
53}
54
55fn init_panic_hook() {
56 std::panic::set_hook(Box::new(|info| {
57 let payload = info
58 .payload()
59 .downcast_ref::<&str>()
60 .map(|s| s.to_string())
61 .or_else(|| info.payload().downcast_ref::<String>().cloned())
62 .unwrap_or_else(|| "Box<Any>".to_string());
63
64 let backtrace = backtrace::Backtrace::new();
65 let mut backtrace = backtrace
66 .frames()
67 .iter()
68 .flat_map(|frame| {
69 frame
70 .symbols()
71 .iter()
72 .filter_map(|frame| Some(format!("{:#}", frame.name()?)))
73 })
74 .collect::<Vec<_>>();
75
76 // Strip out leading stack frames for rust panic-handling.
77 if let Some(ix) = backtrace
78 .iter()
79 .position(|name| name == "rust_begin_unwind")
80 {
81 backtrace.drain(0..=ix);
82 }
83
84 log::error!(
85 "server: panic occurred: {}\nBacktrace:\n{}",
86 payload,
87 backtrace.join("\n")
88 );
89
90 std::process::abort();
91 }));
92}
93
94fn start_server(
95 stdin_listener: UnixListener,
96 stdout_listener: UnixListener,
97 cx: &mut AppContext,
98) -> Arc<ChannelClient> {
99 // This is the server idle timeout. If no connection comes in in this timeout, the server will shut down.
100 const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
101
102 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
103 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
104 let (app_quit_tx, mut app_quit_rx) = mpsc::unbounded::<()>();
105
106 cx.on_app_quit(move |_| {
107 let mut app_quit_tx = app_quit_tx.clone();
108 async move {
109 log::info!("app quitting. sending signal to server main loop");
110 app_quit_tx.send(()).await.ok();
111 }
112 })
113 .detach();
114
115 cx.spawn(|cx| async move {
116 let mut stdin_incoming = stdin_listener.incoming();
117 let mut stdout_incoming = stdout_listener.incoming();
118
119 loop {
120 let streams = futures::future::join(stdin_incoming.next(), stdout_incoming.next());
121
122 log::info!("server: accepting new connections");
123 let result = select! {
124 streams = streams.fuse() => {
125 let (Some(Ok(stdin_stream)), Some(Ok(stdout_stream))) = streams else {
126 break;
127 };
128 anyhow::Ok((stdin_stream, stdout_stream))
129 }
130 _ = futures::FutureExt::fuse(smol::Timer::after(IDLE_TIMEOUT)) => {
131 log::warn!("server: timed out waiting for new connections after {:?}. exiting.", IDLE_TIMEOUT);
132 cx.update(|cx| {
133 // TODO: This is a hack, because in a headless project, shutdown isn't executed
134 // when calling quit, but it should be.
135 cx.shutdown();
136 cx.quit();
137 })?;
138 break;
139 }
140 _ = app_quit_rx.next().fuse() => {
141 break;
142 }
143 };
144
145 let Ok((mut stdin_stream, mut stdout_stream)) = result else {
146 break;
147 };
148
149 let mut input_buffer = Vec::new();
150 let mut output_buffer = Vec::new();
151 loop {
152 select_biased! {
153 _ = app_quit_rx.next().fuse() => {
154 return anyhow::Ok(());
155 }
156
157 stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => {
158 let message = match stdin_message {
159 Ok(message) => message,
160 Err(error) => {
161 log::warn!("server: error reading message on stdin: {}. exiting.", error);
162 break;
163 }
164 };
165 if let Err(error) = incoming_tx.unbounded_send(message) {
166 log::error!("server: failed to send message to application: {:?}. exiting.", error);
167 return Err(anyhow!(error));
168 }
169 }
170
171 outgoing_message = outgoing_rx.next().fuse() => {
172 let Some(message) = outgoing_message else {
173 log::error!("server: stdout handler, no message");
174 break;
175 };
176
177 if let Err(error) =
178 write_message(&mut stdout_stream, &mut output_buffer, message).await
179 {
180 log::error!("server: failed to write stdout message: {:?}", error);
181 break;
182 }
183 if let Err(error) = stdout_stream.flush().await {
184 log::error!("server: failed to flush stdout message: {:?}", error);
185 break;
186 }
187 }
188 }
189 }
190 }
191 anyhow::Ok(())
192 })
193 .detach();
194
195 ChannelClient::new(incoming_rx, outgoing_tx, cx)
196}
197
198pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: PathBuf) -> Result<()> {
199 log::info!(
200 "server: starting up. pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}",
201 pid_file,
202 stdin_socket,
203 stdout_socket
204 );
205
206 write_pid_file(&pid_file)
207 .with_context(|| format!("failed to write pid file: {:?}", &pid_file))?;
208
209 let stdin_listener = UnixListener::bind(stdin_socket).context("failed to bind stdin socket")?;
210 let stdout_listener =
211 UnixListener::bind(stdout_socket).context("failed to bind stdout socket")?;
212
213 log::debug!("server: starting gpui app");
214 gpui::App::headless().run(move |cx| {
215 settings::init(cx);
216 HeadlessProject::init(cx);
217
218 log::info!("server: gpui app started, initializing server");
219 let session = start_server(stdin_listener, stdout_listener, cx);
220 let project = cx.new_model(|cx| {
221 HeadlessProject::new(session, Arc::new(RealFs::new(Default::default(), None)), cx)
222 });
223
224 mem::forget(project);
225 });
226 log::info!("server: gpui app is shut down. quitting.");
227 Ok(())
228}
229
230pub fn execute_proxy(identifier: String) -> Result<()> {
231 log::debug!("proxy: starting up. PID: {}", std::process::id());
232
233 let project_dir = ensure_project_dir(&identifier)?;
234
235 let pid_file = project_dir.join("server.pid");
236 let stdin_socket = project_dir.join("stdin.sock");
237 let stdout_socket = project_dir.join("stdout.sock");
238 let log_file = project_dir.join("server.log");
239
240 let server_running = check_pid_file(&pid_file)?;
241 if !server_running {
242 spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?;
243 };
244
245 let stdin_task = smol::spawn(async move {
246 let stdin = Async::new(std::io::stdin())?;
247 let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?;
248 handle_io(stdin, stream, "stdin").await
249 });
250
251 let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
252 let stdout = Async::new(std::io::stdout())?;
253 let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?;
254 handle_io(stream, stdout, "stdout").await
255 });
256
257 if let Err(forwarding_result) =
258 smol::block_on(async move { smol::future::race(stdin_task, stdout_task).await })
259 {
260 log::error!(
261 "proxy: failed to forward messages: {:?}, terminating...",
262 forwarding_result
263 );
264 return Err(forwarding_result);
265 }
266
267 Ok(())
268}
269
/// Returns the state directory for `identifier`, creating it (and any missing
/// parents) if necessary.
///
/// The directory is `$HOME/.local/state/zed-remote-server/<identifier>`. When
/// `HOME` is unset, the current directory (`.`) is used as the base instead.
///
/// # Errors
///
/// Returns an error if the directory cannot be created.
fn ensure_project_dir(identifier: &str) -> Result<PathBuf> {
    // `var_os` keeps a non-UTF-8 `HOME` intact; `env::var` would reject it and
    // silently fall back to the current directory.
    let home = env::var_os("HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."));

    let project_dir = home
        .join(".local")
        .join("state")
        .join("zed-remote-server")
        .join(identifier);

    std::fs::create_dir_all(&project_dir)?;

    Ok(project_dir)
}
282
283fn spawn_server(
284 log_file: &Path,
285 pid_file: &Path,
286 stdin_socket: &Path,
287 stdout_socket: &Path,
288) -> Result<()> {
289 if stdin_socket.exists() {
290 std::fs::remove_file(&stdin_socket)?;
291 }
292 if stdout_socket.exists() {
293 std::fs::remove_file(&stdout_socket)?;
294 }
295
296 let binary_name = std::env::current_exe()?;
297 let server_process = std::process::Command::new(binary_name)
298 .arg("run")
299 .arg("--log-file")
300 .arg(log_file)
301 .arg("--pid-file")
302 .arg(pid_file)
303 .arg("--stdin-socket")
304 .arg(stdin_socket)
305 .arg("--stdout-socket")
306 .arg(stdout_socket)
307 .spawn()?;
308
309 log::debug!("proxy: server started. PID: {:?}", server_process.id());
310
311 let mut total_time_waited = std::time::Duration::from_secs(0);
312 let wait_duration = std::time::Duration::from_millis(20);
313 while !stdout_socket.exists() || !stdin_socket.exists() {
314 log::debug!("proxy: waiting for server to be ready to accept connections...");
315 std::thread::sleep(wait_duration);
316 total_time_waited += wait_duration;
317 }
318
319 log::info!(
320 "proxy: server ready to accept connections. total time waited: {:?}",
321 total_time_waited
322 );
323 Ok(())
324}
325
326fn check_pid_file(path: &Path) -> Result<bool> {
327 let Some(pid) = std::fs::read_to_string(&path)
328 .ok()
329 .and_then(|contents| contents.parse::<u32>().ok())
330 else {
331 return Ok(false);
332 };
333
334 log::debug!("proxy: Checking if process with PID {} exists...", pid);
335 match std::process::Command::new("kill")
336 .arg("-0")
337 .arg(pid.to_string())
338 .output()
339 {
340 Ok(output) if output.status.success() => {
341 log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
342 Ok(true)
343 }
344 _ => {
345 log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
346 std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
347 Ok(false)
348 }
349 }
350}
351
352fn write_pid_file(path: &Path) -> Result<()> {
353 if path.exists() {
354 std::fs::remove_file(path)?;
355 }
356 let pid = std::process::id().to_string();
357 log::debug!("server: writing PID {} to file {:?}", pid, path);
358 std::fs::write(path, pid).context("Failed to write PID file")
359}
360
361async fn handle_io<R, W>(mut reader: R, mut writer: W, socket_name: &str) -> Result<()>
362where
363 R: AsyncRead + Unpin,
364 W: AsyncWrite + Unpin,
365{
366 use remote::protocol::read_message_raw;
367
368 let mut buffer = Vec::new();
369 loop {
370 read_message_raw(&mut reader, &mut buffer)
371 .await
372 .with_context(|| format!("proxy: failed to read message from {}", socket_name))?;
373
374 write_size_prefixed_buffer(&mut writer, &mut buffer)
375 .await
376 .with_context(|| format!("proxy: failed to write message to {}", socket_name))?;
377
378 writer.flush().await?;
379
380 buffer.clear();
381 }
382}
383
384async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
385 stream: &mut S,
386 buffer: &mut Vec<u8>,
387) -> Result<()> {
388 let len = buffer.len() as u32;
389 stream.write_all(len.to_le_bytes().as_slice()).await?;
390 stream.write_all(buffer).await?;
391 Ok(())
392}