unix.rs

  1use crate::HeadlessProject;
  2use anyhow::{anyhow, Context, Result};
  3use fs::RealFs;
  4use futures::channel::mpsc;
  5use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
  6use gpui::{AppContext, Context as _};
  7use remote::ssh_session::ChannelClient;
  8use remote::{
  9    json_log::LogRecord,
 10    protocol::{read_message, write_message},
 11};
 12use rpc::proto::Envelope;
 13use smol::Async;
 14use smol::{io::AsyncWriteExt, net::unix::UnixListener, stream::StreamExt as _};
 15use std::{
 16    env,
 17    io::Write,
 18    mem,
 19    path::{Path, PathBuf},
 20    sync::Arc,
 21};
 22
 23pub fn init(log_file: Option<PathBuf>) -> Result<()> {
 24    init_logging(log_file)?;
 25    init_panic_hook();
 26    Ok(())
 27}
 28
 29fn init_logging(log_file: Option<PathBuf>) -> Result<()> {
 30    if let Some(log_file) = log_file {
 31        let target = Box::new(if log_file.exists() {
 32            std::fs::OpenOptions::new()
 33                .append(true)
 34                .open(&log_file)
 35                .context("Failed to open log file in append mode")?
 36        } else {
 37            std::fs::File::create(&log_file).context("Failed to create log file")?
 38        });
 39
 40        env_logger::Builder::from_default_env()
 41            .target(env_logger::Target::Pipe(target))
 42            .init();
 43    } else {
 44        env_logger::builder()
 45            .format(|buf, record| {
 46                serde_json::to_writer(&mut *buf, &LogRecord::new(record))?;
 47                buf.write_all(b"\n")?;
 48                Ok(())
 49            })
 50            .init();
 51    }
 52    Ok(())
 53}
 54
 55fn init_panic_hook() {
 56    std::panic::set_hook(Box::new(|info| {
 57        let payload = info
 58            .payload()
 59            .downcast_ref::<&str>()
 60            .map(|s| s.to_string())
 61            .or_else(|| info.payload().downcast_ref::<String>().cloned())
 62            .unwrap_or_else(|| "Box<Any>".to_string());
 63
 64        let backtrace = backtrace::Backtrace::new();
 65        let mut backtrace = backtrace
 66            .frames()
 67            .iter()
 68            .flat_map(|frame| {
 69                frame
 70                    .symbols()
 71                    .iter()
 72                    .filter_map(|frame| Some(format!("{:#}", frame.name()?)))
 73            })
 74            .collect::<Vec<_>>();
 75
 76        // Strip out leading stack frames for rust panic-handling.
 77        if let Some(ix) = backtrace
 78            .iter()
 79            .position(|name| name == "rust_begin_unwind")
 80        {
 81            backtrace.drain(0..=ix);
 82        }
 83
 84        log::error!(
 85            "server: panic occurred: {}\nBacktrace:\n{}",
 86            payload,
 87            backtrace.join("\n")
 88        );
 89
 90        std::process::abort();
 91    }));
 92}
 93
 94fn start_server(
 95    stdin_listener: UnixListener,
 96    stdout_listener: UnixListener,
 97    cx: &mut AppContext,
 98) -> Arc<ChannelClient> {
 99    // This is the server idle timeout. If no connection comes in in this timeout, the server will shut down.
100    const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
101
102    let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
103    let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
104    let (app_quit_tx, mut app_quit_rx) = mpsc::unbounded::<()>();
105
106    cx.on_app_quit(move |_| {
107        let mut app_quit_tx = app_quit_tx.clone();
108        async move {
109            log::info!("app quitting. sending signal to server main loop");
110            app_quit_tx.send(()).await.ok();
111        }
112    })
113    .detach();
114
115    cx.spawn(|cx| async move {
116        let mut stdin_incoming = stdin_listener.incoming();
117        let mut stdout_incoming = stdout_listener.incoming();
118
119        loop {
120            let streams = futures::future::join(stdin_incoming.next(), stdout_incoming.next());
121
122            log::info!("server: accepting new connections");
123            let result = select! {
124                streams = streams.fuse() => {
125                    let (Some(Ok(stdin_stream)), Some(Ok(stdout_stream))) = streams else {
126                        break;
127                    };
128                    anyhow::Ok((stdin_stream, stdout_stream))
129                }
130                _ = futures::FutureExt::fuse(smol::Timer::after(IDLE_TIMEOUT)) => {
131                    log::warn!("server: timed out waiting for new connections after {:?}. exiting.", IDLE_TIMEOUT);
132                    cx.update(|cx| {
133                        // TODO: This is a hack, because in a headless project, shutdown isn't executed
134                        // when calling quit, but it should be.
135                        cx.shutdown();
136                        cx.quit();
137                    })?;
138                    break;
139                }
140                _ = app_quit_rx.next().fuse() => {
141                    break;
142                }
143            };
144
145            let Ok((mut stdin_stream, mut stdout_stream)) = result else {
146                break;
147            };
148
149            let mut input_buffer = Vec::new();
150            let mut output_buffer = Vec::new();
151            loop {
152                select_biased! {
153                    _ = app_quit_rx.next().fuse() => {
154                        return anyhow::Ok(());
155                    }
156
157                    stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => {
158                        let message = match stdin_message {
159                            Ok(message) => message,
160                            Err(error) => {
161                                log::warn!("server: error reading message on stdin: {}. exiting.", error);
162                                break;
163                            }
164                        };
165                        if let Err(error) = incoming_tx.unbounded_send(message) {
166                            log::error!("server: failed to send message to application: {:?}. exiting.", error);
167                            return Err(anyhow!(error));
168                        }
169                    }
170
171                    outgoing_message  = outgoing_rx.next().fuse() => {
172                        let Some(message) = outgoing_message else {
173                            log::error!("server: stdout handler, no message");
174                            break;
175                        };
176
177                        if let Err(error) =
178                            write_message(&mut stdout_stream, &mut output_buffer, message).await
179                        {
180                            log::error!("server: failed to write stdout message: {:?}", error);
181                            break;
182                        }
183                        if let Err(error) = stdout_stream.flush().await {
184                            log::error!("server: failed to flush stdout message: {:?}", error);
185                            break;
186                        }
187                    }
188                }
189            }
190        }
191        anyhow::Ok(())
192    })
193    .detach();
194
195    ChannelClient::new(incoming_rx, outgoing_tx, cx)
196}
197
198pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: PathBuf) -> Result<()> {
199    log::info!(
200        "server: starting up. pid_file: {:?}, stdin_socket: {:?}, stdout_socket: {:?}",
201        pid_file,
202        stdin_socket,
203        stdout_socket
204    );
205
206    write_pid_file(&pid_file)
207        .with_context(|| format!("failed to write pid file: {:?}", &pid_file))?;
208
209    let stdin_listener = UnixListener::bind(stdin_socket).context("failed to bind stdin socket")?;
210    let stdout_listener =
211        UnixListener::bind(stdout_socket).context("failed to bind stdout socket")?;
212
213    log::debug!("server: starting gpui app");
214    gpui::App::headless().run(move |cx| {
215        settings::init(cx);
216        HeadlessProject::init(cx);
217
218        log::info!("server: gpui app started, initializing server");
219        let session = start_server(stdin_listener, stdout_listener, cx);
220        let project = cx.new_model(|cx| {
221            HeadlessProject::new(session, Arc::new(RealFs::new(Default::default(), None)), cx)
222        });
223
224        mem::forget(project);
225    });
226    log::info!("server: gpui app is shut down. quitting.");
227    Ok(())
228}
229
230pub fn execute_proxy(identifier: String) -> Result<()> {
231    log::debug!("proxy: starting up. PID: {}", std::process::id());
232
233    let project_dir = ensure_project_dir(&identifier)?;
234
235    let pid_file = project_dir.join("server.pid");
236    let stdin_socket = project_dir.join("stdin.sock");
237    let stdout_socket = project_dir.join("stdout.sock");
238    let log_file = project_dir.join("server.log");
239
240    let server_running = check_pid_file(&pid_file)?;
241    if !server_running {
242        spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?;
243    };
244
245    let stdin_task = smol::spawn(async move {
246        let stdin = Async::new(std::io::stdin())?;
247        let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?;
248        handle_io(stdin, stream, "stdin").await
249    });
250
251    let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
252        let stdout = Async::new(std::io::stdout())?;
253        let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?;
254        handle_io(stream, stdout, "stdout").await
255    });
256
257    if let Err(forwarding_result) =
258        smol::block_on(async move { smol::future::race(stdin_task, stdout_task).await })
259    {
260        log::error!(
261            "proxy: failed to forward messages: {:?}, terminating...",
262            forwarding_result
263        );
264        return Err(forwarding_result);
265    }
266
267    Ok(())
268}
269
270fn ensure_project_dir(identifier: &str) -> Result<PathBuf> {
271    let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
272    let project_dir = PathBuf::from(project_dir)
273        .join(".local")
274        .join("state")
275        .join("zed-remote-server")
276        .join(identifier);
277
278    std::fs::create_dir_all(&project_dir)?;
279
280    Ok(project_dir)
281}
282
283fn spawn_server(
284    log_file: &Path,
285    pid_file: &Path,
286    stdin_socket: &Path,
287    stdout_socket: &Path,
288) -> Result<()> {
289    if stdin_socket.exists() {
290        std::fs::remove_file(&stdin_socket)?;
291    }
292    if stdout_socket.exists() {
293        std::fs::remove_file(&stdout_socket)?;
294    }
295
296    let binary_name = std::env::current_exe()?;
297    let server_process = std::process::Command::new(binary_name)
298        .arg("run")
299        .arg("--log-file")
300        .arg(log_file)
301        .arg("--pid-file")
302        .arg(pid_file)
303        .arg("--stdin-socket")
304        .arg(stdin_socket)
305        .arg("--stdout-socket")
306        .arg(stdout_socket)
307        .spawn()?;
308
309    log::debug!("proxy: server started. PID: {:?}", server_process.id());
310
311    let mut total_time_waited = std::time::Duration::from_secs(0);
312    let wait_duration = std::time::Duration::from_millis(20);
313    while !stdout_socket.exists() || !stdin_socket.exists() {
314        log::debug!("proxy: waiting for server to be ready to accept connections...");
315        std::thread::sleep(wait_duration);
316        total_time_waited += wait_duration;
317    }
318
319    log::info!(
320        "proxy: server ready to accept connections. total time waited: {:?}",
321        total_time_waited
322    );
323    Ok(())
324}
325
326fn check_pid_file(path: &Path) -> Result<bool> {
327    let Some(pid) = std::fs::read_to_string(&path)
328        .ok()
329        .and_then(|contents| contents.parse::<u32>().ok())
330    else {
331        return Ok(false);
332    };
333
334    log::debug!("proxy: Checking if process with PID {} exists...", pid);
335    match std::process::Command::new("kill")
336        .arg("-0")
337        .arg(pid.to_string())
338        .output()
339    {
340        Ok(output) if output.status.success() => {
341            log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
342            Ok(true)
343        }
344        _ => {
345            log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
346            std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
347            Ok(false)
348        }
349    }
350}
351
352fn write_pid_file(path: &Path) -> Result<()> {
353    if path.exists() {
354        std::fs::remove_file(path)?;
355    }
356    let pid = std::process::id().to_string();
357    log::debug!("server: writing PID {} to file {:?}", pid, path);
358    std::fs::write(path, pid).context("Failed to write PID file")
359}
360
361async fn handle_io<R, W>(mut reader: R, mut writer: W, socket_name: &str) -> Result<()>
362where
363    R: AsyncRead + Unpin,
364    W: AsyncWrite + Unpin,
365{
366    use remote::protocol::read_message_raw;
367
368    let mut buffer = Vec::new();
369    loop {
370        read_message_raw(&mut reader, &mut buffer)
371            .await
372            .with_context(|| format!("proxy: failed to read message from {}", socket_name))?;
373
374        write_size_prefixed_buffer(&mut writer, &mut buffer)
375            .await
376            .with_context(|| format!("proxy: failed to write message to {}", socket_name))?;
377
378        writer.flush().await?;
379
380        buffer.clear();
381    }
382}
383
384async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
385    stream: &mut S,
386    buffer: &mut Vec<u8>,
387) -> Result<()> {
388    let len = buffer.len() as u32;
389    stream.write_all(len.to_le_bytes().as_slice()).await?;
390    stream.write_all(buffer).await?;
391    Ok(())
392}