unix.rs

  1use crate::HeadlessProject;
  2use anyhow::{anyhow, Context, Result};
  3use fs::RealFs;
  4use futures::channel::mpsc;
  5use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
  6use gpui::{AppContext, Context as _};
  7use remote::proxy::ProxyLaunchError;
  8use remote::ssh_session::ChannelClient;
  9use remote::{
 10    json_log::LogRecord,
 11    protocol::{read_message, write_message},
 12};
 13use rpc::proto::Envelope;
 14use smol::Async;
 15use smol::{io::AsyncWriteExt, net::unix::UnixListener, stream::StreamExt as _};
 16use std::{
 17    env,
 18    io::Write,
 19    mem,
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23
 24pub fn init(log_file: Option<PathBuf>) -> Result<()> {
 25    init_logging(log_file)?;
 26    init_panic_hook();
 27    Ok(())
 28}
 29
 30fn init_logging(log_file: Option<PathBuf>) -> Result<()> {
 31    if let Some(log_file) = log_file {
 32        let target = Box::new(if log_file.exists() {
 33            std::fs::OpenOptions::new()
 34                .append(true)
 35                .open(&log_file)
 36                .context("Failed to open log file in append mode")?
 37        } else {
 38            std::fs::File::create(&log_file).context("Failed to create log file")?
 39        });
 40
 41        env_logger::Builder::from_default_env()
 42            .target(env_logger::Target::Pipe(target))
 43            .init();
 44    } else {
 45        env_logger::builder()
 46            .format(|buf, record| {
 47                serde_json::to_writer(&mut *buf, &LogRecord::new(record))?;
 48                buf.write_all(b"\n")?;
 49                Ok(())
 50            })
 51            .init();
 52    }
 53    Ok(())
 54}
 55
 56fn init_panic_hook() {
 57    std::panic::set_hook(Box::new(|info| {
 58        let payload = info
 59            .payload()
 60            .downcast_ref::<&str>()
 61            .map(|s| s.to_string())
 62            .or_else(|| info.payload().downcast_ref::<String>().cloned())
 63            .unwrap_or_else(|| "Box<Any>".to_string());
 64
 65        let backtrace = backtrace::Backtrace::new();
 66        let mut backtrace = backtrace
 67            .frames()
 68            .iter()
 69            .flat_map(|frame| {
 70                frame
 71                    .symbols()
 72                    .iter()
 73                    .filter_map(|frame| Some(format!("{:#}", frame.name()?)))
 74            })
 75            .collect::<Vec<_>>();
 76
 77        // Strip out leading stack frames for rust panic-handling.
 78        if let Some(ix) = backtrace
 79            .iter()
 80            .position(|name| name == "rust_begin_unwind")
 81        {
 82            backtrace.drain(0..=ix);
 83        }
 84
 85        log::error!(
 86            "server: panic occurred: {}\nBacktrace:\n{}",
 87            payload,
 88            backtrace.join("\n")
 89        );
 90
 91        std::process::abort();
 92    }));
 93}
 94
 95fn start_server(
 96    stdin_listener: UnixListener,
 97    stdout_listener: UnixListener,
 98    cx: &mut AppContext,
 99) -> Arc<ChannelClient> {
100    // This is the server idle timeout. If no connection comes in in this timeout, the server will shut down.
101    const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
102
103    let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
104    let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
105    let (app_quit_tx, mut app_quit_rx) = mpsc::unbounded::<()>();
106
107    cx.on_app_quit(move |_| {
108        let mut app_quit_tx = app_quit_tx.clone();
109        async move {
110            log::info!("app quitting. sending signal to server main loop");
111            app_quit_tx.send(()).await.ok();
112        }
113    })
114    .detach();
115
116    cx.spawn(|cx| async move {
117        let mut stdin_incoming = stdin_listener.incoming();
118        let mut stdout_incoming = stdout_listener.incoming();
119
120        loop {
121            let streams = futures::future::join(stdin_incoming.next(), stdout_incoming.next());
122
123            log::info!("server: accepting new connections");
124            let result = select! {
125                streams = streams.fuse() => {
126                    let (Some(Ok(stdin_stream)), Some(Ok(stdout_stream))) = streams else {
127                        break;
128                    };
129                    anyhow::Ok((stdin_stream, stdout_stream))
130                }
131                _ = futures::FutureExt::fuse(smol::Timer::after(IDLE_TIMEOUT)) => {
132                    log::warn!("server: timed out waiting for new connections after {:?}. exiting.", IDLE_TIMEOUT);
133                    cx.update(|cx| {
134                        // TODO: This is a hack, because in a headless project, shutdown isn't executed
135                        // when calling quit, but it should be.
136                        cx.shutdown();
137                        cx.quit();
138                    })?;
139                    break;
140                }
141                _ = app_quit_rx.next().fuse() => {
142                    break;
143                }
144            };
145
146            let Ok((mut stdin_stream, mut stdout_stream)) = result else {
147                break;
148            };
149
150            let mut input_buffer = Vec::new();
151            let mut output_buffer = Vec::new();
152            loop {
153                select_biased! {
154                    _ = app_quit_rx.next().fuse() => {
155                        return anyhow::Ok(());
156                    }
157
158                    stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => {
159                        let message = match stdin_message {
160                            Ok(message) => message,
161                            Err(error) => {
162                                log::warn!("server: error reading message on stdin: {}. exiting.", error);
163                                break;
164                            }
165                        };
166                        if let Err(error) = incoming_tx.unbounded_send(message) {
167                            log::error!("server: failed to send message to application: {:?}. exiting.", error);
168                            return Err(anyhow!(error));
169                        }
170                    }
171
172                    outgoing_message  = outgoing_rx.next().fuse() => {
173                        let Some(message) = outgoing_message else {
174                            log::error!("server: stdout handler, no message");
175                            break;
176                        };
177
178                        if let Err(error) =
179                            write_message(&mut stdout_stream, &mut output_buffer, message).await
180                        {
181                            log::error!("server: failed to write stdout message: {:?}", error);
182                            break;
183                        }
184                        if let Err(error) = stdout_stream.flush().await {
185                            log::error!("server: failed to flush stdout message: {:?}", error);
186                            break;
187                        }
188                    }
189                }
190            }
191        }
192        anyhow::Ok(())
193    })
194    .detach();
195
196    ChannelClient::new(incoming_rx, outgoing_tx, cx)
197}
198
199pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: PathBuf) -> Result<()> {
200    log::info!(
201        "server: starting up. pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}",
202        pid_file,
203        stdin_socket,
204        stdout_socket
205    );
206
207    write_pid_file(&pid_file)
208        .with_context(|| format!("failed to write pid file: {:?}", &pid_file))?;
209
210    let stdin_listener = UnixListener::bind(stdin_socket).context("failed to bind stdin socket")?;
211    let stdout_listener =
212        UnixListener::bind(stdout_socket).context("failed to bind stdout socket")?;
213
214    log::debug!("server: starting gpui app");
215    gpui::App::headless().run(move |cx| {
216        settings::init(cx);
217        HeadlessProject::init(cx);
218
219        log::info!("server: gpui app started, initializing server");
220        let session = start_server(stdin_listener, stdout_listener, cx);
221        let project = cx.new_model(|cx| {
222            HeadlessProject::new(session, Arc::new(RealFs::new(Default::default(), None)), cx)
223        });
224
225        mem::forget(project);
226    });
227    log::info!("server: gpui app is shut down. quitting.");
228    Ok(())
229}
230
231#[derive(Clone)]
232struct ServerPaths {
233    log_file: PathBuf,
234    pid_file: PathBuf,
235    stdin_socket: PathBuf,
236    stdout_socket: PathBuf,
237}
238
239impl ServerPaths {
240    fn new(identifier: &str) -> Result<Self> {
241        let project_dir = create_state_directory(identifier)?;
242
243        let pid_file = project_dir.join("server.pid");
244        let stdin_socket = project_dir.join("stdin.sock");
245        let stdout_socket = project_dir.join("stdout.sock");
246        let log_file = project_dir.join("server.log");
247
248        Ok(Self {
249            pid_file,
250            stdin_socket,
251            stdout_socket,
252            log_file,
253        })
254    }
255}
256
257pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
258    log::debug!("proxy: starting up. PID: {}", std::process::id());
259
260    let server_paths = ServerPaths::new(&identifier)?;
261
262    let server_pid = check_pid_file(&server_paths.pid_file)?;
263    let server_running = server_pid.is_some();
264    if is_reconnecting {
265        if !server_running {
266            log::error!("proxy: attempted to reconnect, but no server running");
267            return Err(anyhow!(ProxyLaunchError::ServerNotRunning));
268        }
269    } else {
270        if let Some(pid) = server_pid {
271            log::debug!("proxy: found server already running with PID {}. Killing process and cleaning up files...", pid);
272            kill_running_server(pid, &server_paths)?;
273        }
274
275        spawn_server(&server_paths)?;
276    }
277
278    let stdin_task = smol::spawn(async move {
279        let stdin = Async::new(std::io::stdin())?;
280        let stream = smol::net::unix::UnixStream::connect(&server_paths.stdin_socket).await?;
281        handle_io(stdin, stream, "stdin").await
282    });
283
284    let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
285        let stdout = Async::new(std::io::stdout())?;
286        let stream = smol::net::unix::UnixStream::connect(&server_paths.stdout_socket).await?;
287        handle_io(stream, stdout, "stdout").await
288    });
289
290    if let Err(forwarding_result) =
291        smol::block_on(async move { smol::future::race(stdin_task, stdout_task).await })
292    {
293        log::error!(
294            "proxy: failed to forward messages: {:?}, terminating...",
295            forwarding_result
296        );
297        return Err(forwarding_result);
298    }
299
300    Ok(())
301}
302
303fn create_state_directory(identifier: &str) -> Result<PathBuf> {
304    let home_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
305    let server_dir = PathBuf::from(home_dir)
306        .join(".local")
307        .join("state")
308        .join("zed-remote-server")
309        .join(identifier);
310
311    std::fs::create_dir_all(&server_dir)?;
312
313    Ok(server_dir)
314}
315
316fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
317    log::info!("proxy: killing existing server with PID {}", pid);
318    std::process::Command::new("kill")
319        .arg(pid.to_string())
320        .output()
321        .context("proxy: failed to kill existing server")?;
322
323    for file in [&paths.pid_file, &paths.stdin_socket, &paths.stdout_socket] {
324        log::debug!(
325            "proxy: cleaning up file {:?} before starting new server",
326            file
327        );
328        std::fs::remove_file(file).ok();
329    }
330    Ok(())
331}
332
333fn spawn_server(paths: &ServerPaths) -> Result<()> {
334    if paths.stdin_socket.exists() {
335        std::fs::remove_file(&paths.stdin_socket)?;
336    }
337    if paths.stdout_socket.exists() {
338        std::fs::remove_file(&paths.stdout_socket)?;
339    }
340
341    let binary_name = std::env::current_exe()?;
342    let server_process = std::process::Command::new(binary_name)
343        .arg("run")
344        .arg("--log-file")
345        .arg(&paths.log_file)
346        .arg("--pid-file")
347        .arg(&paths.pid_file)
348        .arg("--stdin-socket")
349        .arg(&paths.stdin_socket)
350        .arg("--stdout-socket")
351        .arg(&paths.stdout_socket)
352        .spawn()?;
353
354    log::debug!("proxy: server started. PID: {:?}", server_process.id());
355
356    let mut total_time_waited = std::time::Duration::from_secs(0);
357    let wait_duration = std::time::Duration::from_millis(20);
358    while !paths.stdout_socket.exists() || !paths.stdin_socket.exists() {
359        log::debug!("proxy: waiting for server to be ready to accept connections...");
360        std::thread::sleep(wait_duration);
361        total_time_waited += wait_duration;
362    }
363
364    log::info!(
365        "proxy: server ready to accept connections. total time waited: {:?}",
366        total_time_waited
367    );
368    Ok(())
369}
370
371fn check_pid_file(path: &Path) -> Result<Option<u32>> {
372    let Some(pid) = std::fs::read_to_string(&path)
373        .ok()
374        .and_then(|contents| contents.parse::<u32>().ok())
375    else {
376        return Ok(None);
377    };
378
379    log::debug!("proxy: Checking if process with PID {} exists...", pid);
380    match std::process::Command::new("kill")
381        .arg("-0")
382        .arg(pid.to_string())
383        .output()
384    {
385        Ok(output) if output.status.success() => {
386            log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
387            Ok(Some(pid))
388        }
389        _ => {
390            log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
391            std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
392            Ok(None)
393        }
394    }
395}
396
397fn write_pid_file(path: &Path) -> Result<()> {
398    if path.exists() {
399        std::fs::remove_file(path)?;
400    }
401    let pid = std::process::id().to_string();
402    log::debug!("server: writing PID {} to file {:?}", pid, path);
403    std::fs::write(path, pid).context("Failed to write PID file")
404}
405
406async fn handle_io<R, W>(mut reader: R, mut writer: W, socket_name: &str) -> Result<()>
407where
408    R: AsyncRead + Unpin,
409    W: AsyncWrite + Unpin,
410{
411    use remote::protocol::read_message_raw;
412
413    let mut buffer = Vec::new();
414    loop {
415        read_message_raw(&mut reader, &mut buffer)
416            .await
417            .with_context(|| format!("proxy: failed to read message from {}", socket_name))?;
418
419        write_size_prefixed_buffer(&mut writer, &mut buffer)
420            .await
421            .with_context(|| format!("proxy: failed to write message to {}", socket_name))?;
422
423        writer.flush().await?;
424
425        buffer.clear();
426    }
427}
428
429async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
430    stream: &mut S,
431    buffer: &mut Vec<u8>,
432) -> Result<()> {
433    let len = buffer.len() as u32;
434    stream.write_all(len.to_le_bytes().as_slice()).await?;
435    stream.write_all(buffer).await?;
436    Ok(())
437}