remote server: Do not spawn server when proxy reconnects (#18864)

Thorsten Ball and Bennet Bo Fenner created

This ensures that when the proxy reconnects, it only ever attaches to an
already-running server and never spawns a new server with no state.

This avoids the failure mode where the server process crashes, `proxy`
reconnects and starts a new server, and the user then gets errors like
"unknown buffer id: ..." because the new server has no state.

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>

Change summary

Cargo.lock                       |   1 
crates/remote/Cargo.toml         |   1 
crates/remote/src/proxy.rs       |  25 ++++
crates/remote/src/remote.rs      |   1 
crates/remote/src/ssh_session.rs | 176 +++++++++++++++++++++++----------
crates/remote_server/src/main.rs |  18 +++
crates/remote_server/src/unix.rs | 117 +++++++++++++++------
7 files changed, 248 insertions(+), 91 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9089,6 +9089,7 @@ dependencies = [
  "serde_json",
  "smol",
  "tempfile",
+ "thiserror",
  "util",
 ]
 

crates/remote/Cargo.toml 🔗

@@ -31,6 +31,7 @@ serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
 tempfile.workspace = true
+thiserror.workspace = true
 util.workspace = true
 
 [dev-dependencies]

crates/remote/src/proxy.rs 🔗

@@ -0,0 +1,25 @@
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum ProxyLaunchError {
+    #[error("Attempted reconnect, but server not running.")]
+    ServerNotRunning,
+}
+
+impl ProxyLaunchError {
+    pub fn to_exit_code(&self) -> i32 {
+        match self {
+            // We're using 90 as the exit code, because 0-78 are often taken
+            // by shells and other conventions and >128 also has certain meanings
+            // in certain contexts.
+            Self::ServerNotRunning => 90,
+        }
+    }
+
+    pub fn from_exit_code(exit_code: i32) -> Option<Self> {
+        match exit_code {
+            90 => Some(Self::ServerNotRunning),
+            _ => None,
+        }
+    }
+}

crates/remote/src/ssh_session.rs 🔗

@@ -3,6 +3,7 @@ use crate::{
     protocol::{
         message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
     },
+    proxy::ProxyLaunchError,
 };
 use anyhow::{anyhow, Context as _, Result};
 use collections::HashMap;
@@ -271,6 +272,7 @@ enum State {
         attempts: usize,
     },
     ReconnectExhausted,
+    ServerNotRunning,
 }
 
 impl fmt::Display for State {
@@ -282,6 +284,7 @@ impl fmt::Display for State {
             Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
             Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
             Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
+            Self::ServerNotRunning { .. } => write!(f, "server not running"),
         }
     }
 }
@@ -297,10 +300,23 @@ impl State {
     }
 
     fn can_reconnect(&self) -> bool {
-        matches!(
-            self,
-            Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
-        )
+        match self {
+            Self::Connected { .. }
+            | Self::HeartbeatMissed { .. }
+            | Self::ReconnectFailed { .. } => true,
+            State::Connecting
+            | State::Reconnecting
+            | State::ReconnectExhausted
+            | State::ServerNotRunning => false,
+        }
+    }
+
+    fn is_reconnect_failed(&self) -> bool {
+        matches!(self, Self::ReconnectFailed { .. })
+    }
+
+    fn is_reconnecting(&self) -> bool {
+        matches!(self, Self::Reconnecting { .. })
     }
 
     fn heartbeat_recovered(self) -> Self {
@@ -377,6 +393,7 @@ impl From<&State> for ConnectionState {
             State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
             State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
             State::ReconnectExhausted => Self::Disconnected,
+            State::ServerNotRunning => Self::Disconnected,
         }
     }
 }
@@ -426,6 +443,7 @@ impl SshRemoteClient {
 
             let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
                 unique_identifier,
+                false,
                 connection_options,
                 delegate.clone(),
                 &mut cx,
@@ -496,6 +514,7 @@ impl SshRemoteClient {
             } else {
                 "no state set".to_string()
             };
+            log::info!("aborting reconnect, because not in state that allows reconnecting");
             return Err(anyhow!(error));
         }
 
@@ -527,7 +546,10 @@ impl SshRemoteClient {
                 forwarder,
                 ..
             } => (attempts, ssh_connection, delegate, forwarder),
-            State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
+            State::Connecting
+            | State::Reconnecting
+            | State::ReconnectExhausted
+            | State::ServerNotRunning => unreachable!(),
         };
 
         let attempts = attempts + 1;
@@ -536,11 +558,12 @@ impl SshRemoteClient {
                 "Failed to reconnect to after {} attempts, giving up",
                 MAX_RECONNECT_ATTEMPTS
             );
-            *lock = Some(State::ReconnectExhausted);
+            drop(lock);
+            self.set_state(State::ReconnectExhausted, cx);
             return Ok(());
         }
-        *lock = Some(State::Reconnecting);
         drop(lock);
+        self.set_state(State::Reconnecting, cx);
 
         log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 
@@ -580,6 +603,7 @@ impl SshRemoteClient {
 
             let (ssh_connection, ssh_process) = match Self::establish_connection(
                 identifier,
+                true,
                 connection_options,
                 delegate.clone(),
                 &mut cx,
@@ -616,33 +640,39 @@ impl SshRemoteClient {
         cx.spawn(|this, mut cx| async move {
             let new_state = reconnect_task.await;
             this.update(&mut cx, |this, cx| {
-                match &new_state {
-                    State::Connecting
-                    | State::Reconnecting { .. }
-                    | State::HeartbeatMissed { .. } => {}
-                    State::Connected { .. } => {
-                        log::info!("Successfully reconnected");
-                    }
-                    State::ReconnectFailed {
-                        error, attempts, ..
-                    } => {
-                        log::error!(
-                            "Reconnect attempt {} failed: {:?}. Starting new attempt...",
-                            attempts,
-                            error
-                        );
-                    }
-                    State::ReconnectExhausted => {
-                        log::error!("Reconnect attempt failed and all attempts exhausted");
+                this.try_set_state(cx, |old_state| {
+                    if old_state.is_reconnecting() {
+                        match &new_state {
+                            State::Connecting
+                            | State::Reconnecting { .. }
+                            | State::HeartbeatMissed { .. }
+                            | State::ServerNotRunning => {}
+                            State::Connected { .. } => {
+                                log::info!("Successfully reconnected");
+                            }
+                            State::ReconnectFailed {
+                                error, attempts, ..
+                            } => {
+                                log::error!(
+                                    "Reconnect attempt {} failed: {:?}. Starting new attempt...",
+                                    attempts,
+                                    error
+                                );
+                            }
+                            State::ReconnectExhausted => {
+                                log::error!("Reconnect attempt failed and all attempts exhausted");
+                            }
+                        }
+                        Some(new_state)
+                    } else {
+                        None
                     }
-                }
+                });
 
-                let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
-                *this.state.lock() = Some(new_state);
-                cx.notify();
-                if reconnect_failed {
+                if this.state_is(State::is_reconnect_failed) {
                     this.reconnect(cx)
                 } else {
+                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
                     Ok(())
                 }
             })
@@ -676,8 +706,10 @@ impl SshRemoteClient {
                             missed_heartbeats,
                             MAX_MISSED_HEARTBEATS
                         );
-                    } else {
+                    } else if missed_heartbeats != 0 {
                         missed_heartbeats = 0;
+                    } else {
+                        continue;
                     }
 
                     let result = this.update(&mut cx, |this, mut cx| {
@@ -697,12 +729,12 @@ impl SshRemoteClient {
         cx: &mut ModelContext<Self>,
     ) -> ControlFlow<()> {
         let state = self.state.lock().take().unwrap();
-        self.state.lock().replace(if missed_heartbeats > 0 {
+        let next_state = if missed_heartbeats > 0 {
             state.heartbeat_missed()
         } else {
             state.heartbeat_recovered()
-        });
-        cx.notify();
+        };
+        self.set_state(next_state, cx);
 
         if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
             log::error!(
@@ -743,7 +775,7 @@ impl SshRemoteClient {
                 select_biased! {
                     outgoing = outgoing_rx.next().fuse() => {
                         let Some(outgoing) = outgoing else {
-                            return anyhow::Ok(());
+                            return anyhow::Ok(None);
                         };
 
                         write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
@@ -755,11 +787,7 @@ impl SshRemoteClient {
                                 child_stdin.close().await?;
                                 outgoing_rx.close();
                                 let status = ssh_proxy_process.status().await?;
-                                if !status.success() {
-                                    log::error!("ssh process exited with status: {status:?}");
-                                    return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
-                                }
-                                return Ok(());
+                                return Ok(status.code());
                             }
                             Ok(len) => {
                                 if len < stdout_buffer.len() {
@@ -813,19 +841,56 @@ impl SshRemoteClient {
         cx.spawn(|mut cx| async move {
             let result = io_task.await;
 
-            if let Err(error) = result {
-                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
-                this.update(&mut cx, |this, cx| {
-                    this.reconnect(cx).ok();
-                })?;
+            match result {
+                Ok(Some(exit_code)) => {
+                    if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
+                        match error {
+                            ProxyLaunchError::ServerNotRunning => {
+                                log::error!("failed to reconnect because server is not running");
+                                this.update(&mut cx, |this, cx| {
+                                    this.set_state(State::ServerNotRunning, cx);
+                                })?;
+                            }
+                        }
+                    } else if exit_code > 0 {
+                        log::error!("proxy process terminated unexpectedly");
+                    }
+                }
+                Ok(None) => {}
+                Err(error) => {
+                    log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
+                    this.update(&mut cx, |this, cx| {
+                        this.reconnect(cx).ok();
+                    })?;
+                }
             }
-
             Ok(())
         })
     }
 
+    fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
+        self.state.lock().as_ref().map_or(false, check)
+    }
+
+    fn try_set_state(
+        &self,
+        cx: &mut ModelContext<Self>,
+        map: impl FnOnce(&State) -> Option<State>,
+    ) {
+        if let Some(new_state) = self.state.lock().as_ref().and_then(map) {
+            self.set_state(new_state, cx);
+        }
+    }
+
+    fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
+        log::info!("setting state to '{}'", &state);
+        self.state.lock().replace(state);
+        cx.notify();
+    }
+
     async fn establish_connection(
         unique_identifier: String,
+        reconnect: bool,
         connection_options: SshConnectionOptions,
         delegate: Arc<dyn SshClientDelegate>,
         cx: &mut AsyncAppContext,
@@ -851,14 +916,19 @@ impl SshRemoteClient {
 
         delegate.set_status(Some("Starting proxy"), cx);
 
+        let mut start_proxy_command = format!(
+            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
+            std::env::var("RUST_LOG").unwrap_or_default(),
+            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+            remote_binary_path,
+            unique_identifier,
+        );
+        if reconnect {
+            start_proxy_command.push_str(" --reconnect");
+        }
+
         let ssh_proxy_process = socket
-            .ssh_command(format!(
-                "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
-                std::env::var("RUST_LOG").unwrap_or_default(),
-                std::env::var("RUST_BACKTRACE").unwrap_or_default(),
-                remote_binary_path,
-                unique_identifier,
-            ))
+            .ssh_command(start_proxy_command)
             // IMPORTANT: we kill this process when we drop the task that uses it.
             .kill_on_drop(true)
             .spawn()

crates/remote_server/src/main.rs 🔗

@@ -24,6 +24,8 @@ enum Commands {
         stdout_socket: PathBuf,
     },
     Proxy {
+        #[arg(long)]
+        reconnect: bool,
         #[arg(long)]
         identifier: String,
     },
@@ -37,6 +39,7 @@ fn main() {
 
 #[cfg(not(windows))]
 fn main() -> Result<()> {
+    use remote::proxy::ProxyLaunchError;
     use remote_server::unix::{execute_proxy, execute_run, init};
 
     let cli = Cli::parse();
@@ -51,9 +54,20 @@ fn main() -> Result<()> {
             init(Some(log_file))?;
             execute_run(pid_file, stdin_socket, stdout_socket)
         }
-        Some(Commands::Proxy { identifier }) => {
+        Some(Commands::Proxy {
+            identifier,
+            reconnect,
+        }) => {
             init(None)?;
-            execute_proxy(identifier)
+            match execute_proxy(identifier, reconnect) {
+                Ok(_) => Ok(()),
+                Err(err) => {
+                    if let Some(err) = err.downcast_ref::<ProxyLaunchError>() {
+                        std::process::exit(err.to_exit_code());
+                    }
+                    Err(err)
+                }
+            }
         }
         Some(Commands::Version) => {
             eprintln!("{}", env!("ZED_PKG_VERSION"));

crates/remote_server/src/unix.rs 🔗

@@ -4,6 +4,7 @@ use fs::RealFs;
 use futures::channel::mpsc;
 use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
 use gpui::{AppContext, Context as _};
+use remote::proxy::ProxyLaunchError;
 use remote::ssh_session::ChannelClient;
 use remote::{
     json_log::LogRecord,
@@ -227,30 +228,62 @@ pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: Path
     Ok(())
 }
 
-pub fn execute_proxy(identifier: String) -> Result<()> {
+#[derive(Clone)]
+struct ServerPaths {
+    log_file: PathBuf,
+    pid_file: PathBuf,
+    stdin_socket: PathBuf,
+    stdout_socket: PathBuf,
+}
+
+impl ServerPaths {
+    fn new(identifier: &str) -> Result<Self> {
+        let project_dir = create_state_directory(identifier)?;
+
+        let pid_file = project_dir.join("server.pid");
+        let stdin_socket = project_dir.join("stdin.sock");
+        let stdout_socket = project_dir.join("stdout.sock");
+        let log_file = project_dir.join("server.log");
+
+        Ok(Self {
+            pid_file,
+            stdin_socket,
+            stdout_socket,
+            log_file,
+        })
+    }
+}
+
+pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
     log::debug!("proxy: starting up. PID: {}", std::process::id());
 
-    let project_dir = ensure_project_dir(&identifier)?;
+    let server_paths = ServerPaths::new(&identifier)?;
 
-    let pid_file = project_dir.join("server.pid");
-    let stdin_socket = project_dir.join("stdin.sock");
-    let stdout_socket = project_dir.join("stdout.sock");
-    let log_file = project_dir.join("server.log");
+    let server_pid = check_pid_file(&server_paths.pid_file)?;
+    let server_running = server_pid.is_some();
+    if is_reconnecting {
+        if !server_running {
+            log::error!("proxy: attempted to reconnect, but no server running");
+            return Err(anyhow!(ProxyLaunchError::ServerNotRunning));
+        }
+    } else {
+        if let Some(pid) = server_pid {
+            log::debug!("proxy: found server already running with PID {}. Killing process and cleaning up files...", pid);
+            kill_running_server(pid, &server_paths)?;
+        }
 
-    let server_running = check_pid_file(&pid_file)?;
-    if !server_running {
-        spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?;
-    };
+        spawn_server(&server_paths)?;
+    }
 
     let stdin_task = smol::spawn(async move {
         let stdin = Async::new(std::io::stdin())?;
-        let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?;
+        let stream = smol::net::unix::UnixStream::connect(&server_paths.stdin_socket).await?;
         handle_io(stdin, stream, "stdin").await
     });
 
     let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
         let stdout = Async::new(std::io::stdout())?;
-        let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?;
+        let stream = smol::net::unix::UnixStream::connect(&server_paths.stdout_socket).await?;
         handle_io(stream, stdout, "stdout").await
     });
 
@@ -267,50 +300,62 @@ pub fn execute_proxy(identifier: String) -> Result<()> {
     Ok(())
 }
 
-fn ensure_project_dir(identifier: &str) -> Result<PathBuf> {
-    let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
-    let project_dir = PathBuf::from(project_dir)
+fn create_state_directory(identifier: &str) -> Result<PathBuf> {
+    let home_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
+    let server_dir = PathBuf::from(home_dir)
         .join(".local")
         .join("state")
         .join("zed-remote-server")
         .join(identifier);
 
-    std::fs::create_dir_all(&project_dir)?;
+    std::fs::create_dir_all(&server_dir)?;
 
-    Ok(project_dir)
+    Ok(server_dir)
 }
 
-fn spawn_server(
-    log_file: &Path,
-    pid_file: &Path,
-    stdin_socket: &Path,
-    stdout_socket: &Path,
-) -> Result<()> {
-    if stdin_socket.exists() {
-        std::fs::remove_file(&stdin_socket)?;
+fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
+    log::info!("proxy: killing existing server with PID {}", pid);
+    std::process::Command::new("kill")
+        .arg(pid.to_string())
+        .output()
+        .context("proxy: failed to kill existing server")?;
+
+    for file in [&paths.pid_file, &paths.stdin_socket, &paths.stdout_socket] {
+        log::debug!(
+            "proxy: cleaning up file {:?} before starting new server",
+            file
+        );
+        std::fs::remove_file(file).ok();
+    }
+    Ok(())
+}
+
+fn spawn_server(paths: &ServerPaths) -> Result<()> {
+    if paths.stdin_socket.exists() {
+        std::fs::remove_file(&paths.stdin_socket)?;
     }
-    if stdout_socket.exists() {
-        std::fs::remove_file(&stdout_socket)?;
+    if paths.stdout_socket.exists() {
+        std::fs::remove_file(&paths.stdout_socket)?;
     }
 
     let binary_name = std::env::current_exe()?;
     let server_process = std::process::Command::new(binary_name)
         .arg("run")
         .arg("--log-file")
-        .arg(log_file)
+        .arg(&paths.log_file)
         .arg("--pid-file")
-        .arg(pid_file)
+        .arg(&paths.pid_file)
         .arg("--stdin-socket")
-        .arg(stdin_socket)
+        .arg(&paths.stdin_socket)
         .arg("--stdout-socket")
-        .arg(stdout_socket)
+        .arg(&paths.stdout_socket)
         .spawn()?;
 
     log::debug!("proxy: server started. PID: {:?}", server_process.id());
 
     let mut total_time_waited = std::time::Duration::from_secs(0);
     let wait_duration = std::time::Duration::from_millis(20);
-    while !stdout_socket.exists() || !stdin_socket.exists() {
+    while !paths.stdout_socket.exists() || !paths.stdin_socket.exists() {
         log::debug!("proxy: waiting for server to be ready to accept connections...");
         std::thread::sleep(wait_duration);
         total_time_waited += wait_duration;
@@ -323,12 +368,12 @@ fn spawn_server(
     Ok(())
 }
 
-fn check_pid_file(path: &Path) -> Result<bool> {
+fn check_pid_file(path: &Path) -> Result<Option<u32>> {
     let Some(pid) = std::fs::read_to_string(&path)
         .ok()
         .and_then(|contents| contents.parse::<u32>().ok())
     else {
-        return Ok(false);
+        return Ok(None);
     };
 
     log::debug!("proxy: Checking if process with PID {} exists...", pid);
@@ -339,12 +384,12 @@ fn check_pid_file(path: &Path) -> Result<bool> {
     {
         Ok(output) if output.status.success() => {
             log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
-            Ok(true)
+            Ok(Some(pid))
         }
         _ => {
             log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
             std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
-            Ok(false)
+            Ok(None)
         }
     }
 }