1use crate::HeadlessProject;
2use anyhow::{anyhow, Context, Result};
3use fs::RealFs;
4use futures::channel::mpsc;
5use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
6use gpui::{AppContext, Context as _};
7use remote::ssh_session::ChannelClient;
8use remote::{
9 json_log::LogRecord,
10 protocol::{read_message, write_message},
11};
12use rpc::proto::Envelope;
13use smol::Async;
14use smol::{io::AsyncWriteExt, net::unix::UnixListener, stream::StreamExt as _};
15use std::{
16 env,
17 io::Write,
18 mem,
19 path::{Path, PathBuf},
20 sync::Arc,
21};
22
23pub fn init_logging(log_file: Option<PathBuf>) -> Result<()> {
24 if let Some(log_file) = log_file {
25 let target = Box::new(if log_file.exists() {
26 std::fs::OpenOptions::new()
27 .append(true)
28 .open(&log_file)
29 .context("Failed to open log file in append mode")?
30 } else {
31 std::fs::File::create(&log_file).context("Failed to create log file")?
32 });
33
34 env_logger::Builder::from_default_env()
35 .target(env_logger::Target::Pipe(target))
36 .init();
37 } else {
38 env_logger::builder()
39 .format(|buf, record| {
40 serde_json::to_writer(&mut *buf, &LogRecord::new(record))?;
41 buf.write_all(b"\n")?;
42 Ok(())
43 })
44 .init();
45 }
46 Ok(())
47}
48
49fn start_server(
50 stdin_listener: UnixListener,
51 stdout_listener: UnixListener,
52 cx: &mut AppContext,
53) -> Arc<ChannelClient> {
54 // This is the server idle timeout. If no connection comes in in this timeout, the server will shut down.
55 const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
56
57 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
58 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
59 let (app_quit_tx, mut app_quit_rx) = mpsc::unbounded::<()>();
60
61 cx.on_app_quit(move |_| {
62 let mut app_quit_tx = app_quit_tx.clone();
63 async move {
64 app_quit_tx.send(()).await.ok();
65 }
66 })
67 .detach();
68
69 cx.spawn(|cx| async move {
70 let mut stdin_incoming = stdin_listener.incoming();
71 let mut stdout_incoming = stdout_listener.incoming();
72
73 loop {
74 let streams = futures::future::join(stdin_incoming.next(), stdout_incoming.next());
75
76 log::info!("server: accepting new connections");
77 let result = select! {
78 streams = streams.fuse() => {
79 let (Some(Ok(stdin_stream)), Some(Ok(stdout_stream))) = streams else {
80 break;
81 };
82 anyhow::Ok((stdin_stream, stdout_stream))
83 }
84 _ = futures::FutureExt::fuse(smol::Timer::after(IDLE_TIMEOUT)) => {
85 log::warn!("server: timed out waiting for new connections after {:?}. exiting.", IDLE_TIMEOUT);
86 cx.update(|cx| {
87 // TODO: This is a hack, because in a headless project, shutdown isn't executed
88 // when calling quit, but it should be.
89 cx.shutdown();
90 cx.quit();
91 })?;
92 break;
93 }
94 _ = app_quit_rx.next().fuse() => {
95 break;
96 }
97 };
98
99 let Ok((mut stdin_stream, mut stdout_stream)) = result else {
100 break;
101 };
102
103 let mut input_buffer = Vec::new();
104 let mut output_buffer = Vec::new();
105 loop {
106 select_biased! {
107 _ = app_quit_rx.next().fuse() => {
108 return anyhow::Ok(());
109 }
110
111 stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => {
112 let message = match stdin_message {
113 Ok(message) => message,
114 Err(error) => {
115 log::warn!("server: error reading message on stdin: {}. exiting.", error);
116 break;
117 }
118 };
119 if let Err(error) = incoming_tx.unbounded_send(message) {
120 log::error!("server: failed to send message to application: {:?}. exiting.", error);
121 return Err(anyhow!(error));
122 }
123 }
124
125 outgoing_message = outgoing_rx.next().fuse() => {
126 let Some(message) = outgoing_message else {
127 log::error!("server: stdout handler, no message");
128 break;
129 };
130
131 if let Err(error) =
132 write_message(&mut stdout_stream, &mut output_buffer, message).await
133 {
134 log::error!("server: failed to write stdout message: {:?}", error);
135 break;
136 }
137 if let Err(error) = stdout_stream.flush().await {
138 log::error!("server: failed to flush stdout message: {:?}", error);
139 break;
140 }
141 }
142 }
143 }
144 }
145 anyhow::Ok(())
146 })
147 .detach();
148
149 ChannelClient::new(incoming_rx, outgoing_tx, cx)
150}
151
152pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: PathBuf) -> Result<()> {
153 write_pid_file(&pid_file)
154 .with_context(|| format!("failed to write pid file: {:?}", &pid_file))?;
155
156 let stdin_listener = UnixListener::bind(stdin_socket).context("failed to bind stdin socket")?;
157 let stdout_listener =
158 UnixListener::bind(stdout_socket).context("failed to bind stdout socket")?;
159
160 gpui::App::headless().run(move |cx| {
161 settings::init(cx);
162 HeadlessProject::init(cx);
163
164 let session = start_server(stdin_listener, stdout_listener, cx);
165 let project = cx.new_model(|cx| {
166 HeadlessProject::new(session, Arc::new(RealFs::new(Default::default(), None)), cx)
167 });
168
169 mem::forget(project);
170 });
171 log::info!("server: gpui app is shut down. quitting.");
172 Ok(())
173}
174
175pub fn execute_proxy(identifier: String) -> Result<()> {
176 log::debug!("proxy: starting up. PID: {}", std::process::id());
177
178 let project_dir = ensure_project_dir(&identifier)?;
179
180 let pid_file = project_dir.join("server.pid");
181 let stdin_socket = project_dir.join("stdin.sock");
182 let stdout_socket = project_dir.join("stdout.sock");
183 let log_file = project_dir.join("server.log");
184
185 let server_running = check_pid_file(&pid_file)?;
186 if !server_running {
187 spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?;
188 };
189
190 let stdin_task = smol::spawn(async move {
191 let stdin = Async::new(std::io::stdin())?;
192 let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?;
193 handle_io(stdin, stream, "stdin").await
194 });
195
196 let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
197 let stdout = Async::new(std::io::stdout())?;
198 let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?;
199 handle_io(stream, stdout, "stdout").await
200 });
201
202 if let Err(forwarding_result) =
203 smol::block_on(async move { smol::future::race(stdin_task, stdout_task).await })
204 {
205 log::error!(
206 "proxy: failed to forward messages: {:?}, terminating...",
207 forwarding_result
208 );
209 return Err(forwarding_result);
210 }
211
212 Ok(())
213}
214
215fn ensure_project_dir(identifier: &str) -> Result<PathBuf> {
216 let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
217 let project_dir = PathBuf::from(project_dir)
218 .join(".local")
219 .join("state")
220 .join("zed-remote-server")
221 .join(identifier);
222
223 std::fs::create_dir_all(&project_dir)?;
224
225 Ok(project_dir)
226}
227
228fn spawn_server(
229 log_file: &Path,
230 pid_file: &Path,
231 stdin_socket: &Path,
232 stdout_socket: &Path,
233) -> Result<()> {
234 if stdin_socket.exists() {
235 std::fs::remove_file(&stdin_socket)?;
236 }
237 if stdout_socket.exists() {
238 std::fs::remove_file(&stdout_socket)?;
239 }
240
241 let binary_name = std::env::current_exe()?;
242 let server_process = std::process::Command::new(binary_name)
243 .arg("run")
244 .arg("--log-file")
245 .arg(log_file)
246 .arg("--pid-file")
247 .arg(pid_file)
248 .arg("--stdin-socket")
249 .arg(stdin_socket)
250 .arg("--stdout-socket")
251 .arg(stdout_socket)
252 .spawn()?;
253
254 log::debug!("proxy: server started. PID: {:?}", server_process.id());
255
256 let mut total_time_waited = std::time::Duration::from_secs(0);
257 let wait_duration = std::time::Duration::from_millis(20);
258 while !stdout_socket.exists() || !stdin_socket.exists() {
259 log::debug!("proxy: waiting for server to be ready to accept connections...");
260 std::thread::sleep(wait_duration);
261 total_time_waited += wait_duration;
262 }
263
264 log::info!(
265 "proxy: server ready to accept connections. total time waited: {:?}",
266 total_time_waited
267 );
268 Ok(())
269}
270
271fn check_pid_file(path: &Path) -> Result<bool> {
272 let Some(pid) = std::fs::read_to_string(&path)
273 .ok()
274 .and_then(|contents| contents.parse::<u32>().ok())
275 else {
276 return Ok(false);
277 };
278
279 log::debug!("proxy: Checking if process with PID {} exists...", pid);
280 match std::process::Command::new("kill")
281 .arg("-0")
282 .arg(pid.to_string())
283 .output()
284 {
285 Ok(output) if output.status.success() => {
286 log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
287 Ok(true)
288 }
289 _ => {
290 log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
291 std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
292 Ok(false)
293 }
294 }
295}
296
297fn write_pid_file(path: &Path) -> Result<()> {
298 if path.exists() {
299 std::fs::remove_file(path)?;
300 }
301
302 std::fs::write(path, std::process::id().to_string()).context("Failed to write PID file")
303}
304
305async fn handle_io<R, W>(mut reader: R, mut writer: W, socket_name: &str) -> Result<()>
306where
307 R: AsyncRead + Unpin,
308 W: AsyncWrite + Unpin,
309{
310 use remote::protocol::read_message_raw;
311
312 let mut buffer = Vec::new();
313 loop {
314 read_message_raw(&mut reader, &mut buffer)
315 .await
316 .with_context(|| format!("proxy: failed to read message from {}", socket_name))?;
317
318 write_size_prefixed_buffer(&mut writer, &mut buffer)
319 .await
320 .with_context(|| format!("proxy: failed to write message to {}", socket_name))?;
321
322 writer.flush().await?;
323
324 buffer.clear();
325 }
326}
327
328async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
329 stream: &mut S,
330 buffer: &mut Vec<u8>,
331) -> Result<()> {
332 let len = buffer.len() as u32;
333 stream.write_all(len.to_le_bytes().as_slice()).await?;
334 stream.write_all(buffer).await?;
335 Ok(())
336}