Cargo.lock 🔗
@@ -9089,6 +9089,7 @@ dependencies = [
"serde_json",
"smol",
"tempfile",
+ "thiserror",
"util",
]
Thorsten Ball and Bennet Bo Fenner created
This ensures that we only ever reconnect to a server that is still running,
rather than spawning a new server with no state.
This avoids the problem of the server process crashing, `proxy`
reconnecting, starting a new server, and the user getting errors like
"unknown buffer id: ...".
Release Notes:
- N/A
---------
Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Cargo.lock | 1
crates/remote/Cargo.toml | 1
crates/remote/src/proxy.rs | 25 ++++
crates/remote/src/remote.rs | 1
crates/remote/src/ssh_session.rs | 176 +++++++++++++++++++++++----------
crates/remote_server/src/main.rs | 18 +++
crates/remote_server/src/unix.rs | 117 +++++++++++++++------
7 files changed, 248 insertions(+), 91 deletions(-)
@@ -9089,6 +9089,7 @@ dependencies = [
"serde_json",
"smol",
"tempfile",
+ "thiserror",
"util",
]
@@ -31,6 +31,7 @@ serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tempfile.workspace = true
+thiserror.workspace = true
util.workspace = true
[dev-dependencies]
@@ -0,0 +1,25 @@
+use thiserror::Error;
+
+ /// Errors that prevent the proxy from launching and that are encoded as
+ /// dedicated process exit codes via [`ProxyLaunchError::to_exit_code`] and
+ /// decoded again via [`ProxyLaunchError::from_exit_code`].
+ #[derive(Error, Debug)]
+ pub enum ProxyLaunchError {
+     /// A reconnect was attempted, but no server was running.
+     #[error("Attempted reconnect, but server not running.")]
+     ServerNotRunning,
+ }
+
+ impl ProxyLaunchError {
+     // We're using 90 as the exit code, because 0-78 are often taken
+     // by shells and other conventions and >128 also has certain meanings
+     // in certain contexts.
+     //
+     // Kept as a single constant so that `to_exit_code` and `from_exit_code`
+     // cannot drift apart.
+     const SERVER_NOT_RUNNING_EXIT_CODE: i32 = 90;
+
+     /// Returns the process exit code that represents this error.
+     pub fn to_exit_code(&self) -> i32 {
+         match self {
+             Self::ServerNotRunning => Self::SERVER_NOT_RUNNING_EXIT_CODE,
+         }
+     }
+
+     /// Maps a process exit code back to the error it represents.
+     ///
+     /// Returns `None` when `exit_code` is not one of the codes produced by
+     /// [`ProxyLaunchError::to_exit_code`].
+     pub fn from_exit_code(exit_code: i32) -> Option<Self> {
+         match exit_code {
+             Self::SERVER_NOT_RUNNING_EXIT_CODE => Some(Self::ServerNotRunning),
+             _ => None,
+         }
+     }
+ }
@@ -1,5 +1,6 @@
pub mod json_log;
pub mod protocol;
+pub mod proxy;
pub mod ssh_session;
pub use ssh_session::{
@@ -3,6 +3,7 @@ use crate::{
protocol::{
message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
},
+ proxy::ProxyLaunchError,
};
use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
@@ -271,6 +272,7 @@ enum State {
attempts: usize,
},
ReconnectExhausted,
+ ServerNotRunning,
}
impl fmt::Display for State {
@@ -282,6 +284,7 @@ impl fmt::Display for State {
Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
+ Self::ServerNotRunning { .. } => write!(f, "server not running"),
}
}
}
@@ -297,10 +300,23 @@ impl State {
}
fn can_reconnect(&self) -> bool {
- matches!(
- self,
- Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
- )
+ match self {
+ Self::Connected { .. }
+ | Self::HeartbeatMissed { .. }
+ | Self::ReconnectFailed { .. } => true,
+ State::Connecting
+ | State::Reconnecting
+ | State::ReconnectExhausted
+ | State::ServerNotRunning => false,
+ }
+ }
+
+ fn is_reconnect_failed(&self) -> bool {
+ matches!(self, Self::ReconnectFailed { .. })
+ }
+
+ fn is_reconnecting(&self) -> bool {
+ matches!(self, Self::Reconnecting { .. })
}
fn heartbeat_recovered(self) -> Self {
@@ -377,6 +393,7 @@ impl From<&State> for ConnectionState {
State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
State::ReconnectExhausted => Self::Disconnected,
+ State::ServerNotRunning => Self::Disconnected,
}
}
}
@@ -426,6 +443,7 @@ impl SshRemoteClient {
let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
unique_identifier,
+ false,
connection_options,
delegate.clone(),
&mut cx,
@@ -496,6 +514,7 @@ impl SshRemoteClient {
} else {
"no state set".to_string()
};
+ log::info!("aborting reconnect, because not in state that allows reconnecting");
return Err(anyhow!(error));
}
@@ -527,7 +546,10 @@ impl SshRemoteClient {
forwarder,
..
} => (attempts, ssh_connection, delegate, forwarder),
- State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
+ State::Connecting
+ | State::Reconnecting
+ | State::ReconnectExhausted
+ | State::ServerNotRunning => unreachable!(),
};
let attempts = attempts + 1;
@@ -536,11 +558,12 @@ impl SshRemoteClient {
"Failed to reconnect to after {} attempts, giving up",
MAX_RECONNECT_ATTEMPTS
);
- *lock = Some(State::ReconnectExhausted);
+ drop(lock);
+ self.set_state(State::ReconnectExhausted, cx);
return Ok(());
}
- *lock = Some(State::Reconnecting);
drop(lock);
+ self.set_state(State::Reconnecting, cx);
log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
@@ -580,6 +603,7 @@ impl SshRemoteClient {
let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier,
+ true,
connection_options,
delegate.clone(),
&mut cx,
@@ -616,33 +640,39 @@ impl SshRemoteClient {
cx.spawn(|this, mut cx| async move {
let new_state = reconnect_task.await;
this.update(&mut cx, |this, cx| {
- match &new_state {
- State::Connecting
- | State::Reconnecting { .. }
- | State::HeartbeatMissed { .. } => {}
- State::Connected { .. } => {
- log::info!("Successfully reconnected");
- }
- State::ReconnectFailed {
- error, attempts, ..
- } => {
- log::error!(
- "Reconnect attempt {} failed: {:?}. Starting new attempt...",
- attempts,
- error
- );
- }
- State::ReconnectExhausted => {
- log::error!("Reconnect attempt failed and all attempts exhausted");
+ this.try_set_state(cx, |old_state| {
+ if old_state.is_reconnecting() {
+ match &new_state {
+ State::Connecting
+ | State::Reconnecting { .. }
+ | State::HeartbeatMissed { .. }
+ | State::ServerNotRunning => {}
+ State::Connected { .. } => {
+ log::info!("Successfully reconnected");
+ }
+ State::ReconnectFailed {
+ error, attempts, ..
+ } => {
+ log::error!(
+ "Reconnect attempt {} failed: {:?}. Starting new attempt...",
+ attempts,
+ error
+ );
+ }
+ State::ReconnectExhausted => {
+ log::error!("Reconnect attempt failed and all attempts exhausted");
+ }
+ }
+ Some(new_state)
+ } else {
+ None
}
- }
+ });
- let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
- *this.state.lock() = Some(new_state);
- cx.notify();
- if reconnect_failed {
+ if this.state_is(State::is_reconnect_failed) {
this.reconnect(cx)
} else {
+ log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
Ok(())
}
})
@@ -676,8 +706,10 @@ impl SshRemoteClient {
missed_heartbeats,
MAX_MISSED_HEARTBEATS
);
- } else {
+ } else if missed_heartbeats != 0 {
missed_heartbeats = 0;
+ } else {
+ continue;
}
let result = this.update(&mut cx, |this, mut cx| {
@@ -697,12 +729,12 @@ impl SshRemoteClient {
cx: &mut ModelContext<Self>,
) -> ControlFlow<()> {
let state = self.state.lock().take().unwrap();
- self.state.lock().replace(if missed_heartbeats > 0 {
+ let next_state = if missed_heartbeats > 0 {
state.heartbeat_missed()
} else {
state.heartbeat_recovered()
- });
- cx.notify();
+ };
+ self.set_state(next_state, cx);
if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
log::error!(
@@ -743,7 +775,7 @@ impl SshRemoteClient {
select_biased! {
outgoing = outgoing_rx.next().fuse() => {
let Some(outgoing) = outgoing else {
- return anyhow::Ok(());
+ return anyhow::Ok(None);
};
write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
@@ -755,11 +787,7 @@ impl SshRemoteClient {
child_stdin.close().await?;
outgoing_rx.close();
let status = ssh_proxy_process.status().await?;
- if !status.success() {
- log::error!("ssh process exited with status: {status:?}");
- return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
- }
- return Ok(());
+ return Ok(status.code());
}
Ok(len) => {
if len < stdout_buffer.len() {
@@ -813,19 +841,56 @@ impl SshRemoteClient {
cx.spawn(|mut cx| async move {
let result = io_task.await;
- if let Err(error) = result {
- log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
- this.update(&mut cx, |this, cx| {
- this.reconnect(cx).ok();
- })?;
+ match result {
+ Ok(Some(exit_code)) => {
+ if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
+ match error {
+ ProxyLaunchError::ServerNotRunning => {
+ log::error!("failed to reconnect because server is not running");
+ this.update(&mut cx, |this, cx| {
+ this.set_state(State::ServerNotRunning, cx);
+ })?;
+ }
+ }
+ } else if exit_code > 0 {
+ log::error!("proxy process terminated unexpectedly");
+ }
+ }
+ Ok(None) => {}
+ Err(error) => {
+ log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
+ this.update(&mut cx, |this, cx| {
+ this.reconnect(cx).ok();
+ })?;
+ }
}
-
Ok(())
})
}
+ /// Runs `check` against the current state and returns its verdict;
+ /// yields `false` when no state is currently set.
+ fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
+     match self.state.lock().as_ref() {
+         Some(current) => check(current),
+         None => false,
+     }
+ }
+
+ /// Conditionally transitions to a new state.
+ ///
+ /// `map` is called with the current state (if one is set). When it returns
+ /// `Some(new_state)`, that state is stored and observers are notified;
+ /// when it returns `None`, or no state is set, nothing changes.
+ fn try_set_state(
+     &self,
+     cx: &mut ModelContext<Self>,
+     map: impl FnOnce(&State) -> Option<State>,
+ ) {
+     // Take the lock exactly once and keep it in a named guard. Locking in
+     // the scrutinee of an `if let` keeps that temporary guard alive for the
+     // whole body, so locking `self.state` again from inside the body would
+     // block forever on a non-reentrant mutex.
+     let mut lock = self.state.lock();
+     let new_state = lock.as_ref().and_then(map);
+
+     if let Some(new_state) = new_state {
+         log::info!("setting state to '{}'", &new_state);
+         lock.replace(new_state);
+         // Release the lock before notifying observers.
+         drop(lock);
+         cx.notify();
+     }
+ }
+
+ /// Unconditionally stores `state` as the current state, logs the
+ /// transition, and notifies observers.
+ fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
+     log::info!("setting state to '{}'", &state);
+     *self.state.lock() = Some(state);
+     cx.notify();
+ }
+
async fn establish_connection(
unique_identifier: String,
+ reconnect: bool,
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
@@ -851,14 +916,19 @@ impl SshRemoteClient {
delegate.set_status(Some("Starting proxy"), cx);
+ let mut start_proxy_command = format!(
+ "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
+ std::env::var("RUST_LOG").unwrap_or_default(),
+ std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+ remote_binary_path,
+ unique_identifier,
+ );
+ if reconnect {
+ start_proxy_command.push_str(" --reconnect");
+ }
+
let ssh_proxy_process = socket
- .ssh_command(format!(
- "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
- std::env::var("RUST_LOG").unwrap_or_default(),
- std::env::var("RUST_BACKTRACE").unwrap_or_default(),
- remote_binary_path,
- unique_identifier,
- ))
+ .ssh_command(start_proxy_command)
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
@@ -24,6 +24,8 @@ enum Commands {
stdout_socket: PathBuf,
},
Proxy {
+ #[arg(long)]
+ reconnect: bool,
#[arg(long)]
identifier: String,
},
@@ -37,6 +39,7 @@ fn main() {
#[cfg(not(windows))]
fn main() -> Result<()> {
+ use remote::proxy::ProxyLaunchError;
use remote_server::unix::{execute_proxy, execute_run, init};
let cli = Cli::parse();
@@ -51,9 +54,20 @@ fn main() -> Result<()> {
init(Some(log_file))?;
execute_run(pid_file, stdin_socket, stdout_socket)
}
- Some(Commands::Proxy { identifier }) => {
+ Some(Commands::Proxy {
+ identifier,
+ reconnect,
+ }) => {
init(None)?;
- execute_proxy(identifier)
+ match execute_proxy(identifier, reconnect) {
+ Ok(_) => Ok(()),
+ Err(err) => {
+ if let Some(err) = err.downcast_ref::<ProxyLaunchError>() {
+ std::process::exit(err.to_exit_code());
+ }
+ Err(err)
+ }
+ }
}
Some(Commands::Version) => {
eprintln!("{}", env!("ZED_PKG_VERSION"));
@@ -4,6 +4,7 @@ use fs::RealFs;
use futures::channel::mpsc;
use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt};
use gpui::{AppContext, Context as _};
+use remote::proxy::ProxyLaunchError;
use remote::ssh_session::ChannelClient;
use remote::{
json_log::LogRecord,
@@ -227,30 +228,62 @@ pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: Path
Ok(())
}
-pub fn execute_proxy(identifier: String) -> Result<()> {
+ /// Paths of the files a server instance uses inside its state directory.
+ #[derive(Clone)]
+ struct ServerPaths {
+     log_file: PathBuf,
+     pid_file: PathBuf,
+     stdin_socket: PathBuf,
+     stdout_socket: PathBuf,
+ }
+
+ impl ServerPaths {
+     /// Creates the state directory for `identifier` (via
+     /// `create_state_directory`) and derives all file paths from it.
+     fn new(identifier: &str) -> Result<Self> {
+         let state_dir = create_state_directory(identifier)?;
+
+         Ok(Self {
+             log_file: state_dir.join("server.log"),
+             pid_file: state_dir.join("server.pid"),
+             stdin_socket: state_dir.join("stdin.sock"),
+             stdout_socket: state_dir.join("stdout.sock"),
+         })
+     }
+ }
+
+pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> {
log::debug!("proxy: starting up. PID: {}", std::process::id());
- let project_dir = ensure_project_dir(&identifier)?;
+ let server_paths = ServerPaths::new(&identifier)?;
- let pid_file = project_dir.join("server.pid");
- let stdin_socket = project_dir.join("stdin.sock");
- let stdout_socket = project_dir.join("stdout.sock");
- let log_file = project_dir.join("server.log");
+ let server_pid = check_pid_file(&server_paths.pid_file)?;
+ let server_running = server_pid.is_some();
+ if is_reconnecting {
+ if !server_running {
+ log::error!("proxy: attempted to reconnect, but no server running");
+ return Err(anyhow!(ProxyLaunchError::ServerNotRunning));
+ }
+ } else {
+ if let Some(pid) = server_pid {
+ log::debug!("proxy: found server already running with PID {}. Killing process and cleaning up files...", pid);
+ kill_running_server(pid, &server_paths)?;
+ }
- let server_running = check_pid_file(&pid_file)?;
- if !server_running {
- spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?;
- };
+ spawn_server(&server_paths)?;
+ }
let stdin_task = smol::spawn(async move {
let stdin = Async::new(std::io::stdin())?;
- let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?;
+ let stream = smol::net::unix::UnixStream::connect(&server_paths.stdin_socket).await?;
handle_io(stdin, stream, "stdin").await
});
let stdout_task: smol::Task<Result<()>> = smol::spawn(async move {
let stdout = Async::new(std::io::stdout())?;
- let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?;
+ let stream = smol::net::unix::UnixStream::connect(&server_paths.stdout_socket).await?;
handle_io(stream, stdout, "stdout").await
});
@@ -267,50 +300,62 @@ pub fn execute_proxy(identifier: String) -> Result<()> {
Ok(())
}
-fn ensure_project_dir(identifier: &str) -> Result<PathBuf> {
- let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
- let project_dir = PathBuf::from(project_dir)
+fn create_state_directory(identifier: &str) -> Result<PathBuf> {
+ let home_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string());
+ let server_dir = PathBuf::from(home_dir)
.join(".local")
.join("state")
.join("zed-remote-server")
.join(identifier);
- std::fs::create_dir_all(&project_dir)?;
+ std::fs::create_dir_all(&server_dir)?;
- Ok(project_dir)
+ Ok(server_dir)
}
-fn spawn_server(
- log_file: &Path,
- pid_file: &Path,
- stdin_socket: &Path,
- stdout_socket: &Path,
-) -> Result<()> {
- if stdin_socket.exists() {
- std::fs::remove_file(&stdin_socket)?;
+ /// Asks the server process `pid` to terminate by running `kill <pid>` and
+ /// then removes the PID file and both sockets listed in `paths`.
+ ///
+ /// Returns an error only when the `kill` command itself cannot be run.
+ /// A non-zero exit status from `kill` is logged as a warning, and failures
+ /// to remove individual files are ignored, so cleanup stays best-effort.
+ fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> {
+     log::info!("proxy: killing existing server with PID {}", pid);
+     let output = std::process::Command::new("kill")
+         .arg(pid.to_string())
+         .output()
+         .context("proxy: failed to kill existing server")?;
+
+     // Surface a failed `kill` instead of silently dropping it; keep going so
+     // that the stale files below are still cleaned up.
+     if !output.status.success() {
+         log::warn!(
+             "proxy: `kill {}` exited with {}: {}",
+             pid,
+             output.status,
+             String::from_utf8_lossy(&output.stderr).trim()
+         );
+     }
+
+     for file in [&paths.pid_file, &paths.stdin_socket, &paths.stdout_socket] {
+         log::debug!(
+             "proxy: cleaning up file {:?} before starting new server",
+             file
+         );
+         std::fs::remove_file(file).ok();
+     }
+     Ok(())
+ }
+
+fn spawn_server(paths: &ServerPaths) -> Result<()> {
+ if paths.stdin_socket.exists() {
+ std::fs::remove_file(&paths.stdin_socket)?;
}
- if stdout_socket.exists() {
- std::fs::remove_file(&stdout_socket)?;
+ if paths.stdout_socket.exists() {
+ std::fs::remove_file(&paths.stdout_socket)?;
}
let binary_name = std::env::current_exe()?;
let server_process = std::process::Command::new(binary_name)
.arg("run")
.arg("--log-file")
- .arg(log_file)
+ .arg(&paths.log_file)
.arg("--pid-file")
- .arg(pid_file)
+ .arg(&paths.pid_file)
.arg("--stdin-socket")
- .arg(stdin_socket)
+ .arg(&paths.stdin_socket)
.arg("--stdout-socket")
- .arg(stdout_socket)
+ .arg(&paths.stdout_socket)
.spawn()?;
log::debug!("proxy: server started. PID: {:?}", server_process.id());
let mut total_time_waited = std::time::Duration::from_secs(0);
let wait_duration = std::time::Duration::from_millis(20);
- while !stdout_socket.exists() || !stdin_socket.exists() {
+ while !paths.stdout_socket.exists() || !paths.stdin_socket.exists() {
log::debug!("proxy: waiting for server to be ready to accept connections...");
std::thread::sleep(wait_duration);
total_time_waited += wait_duration;
@@ -323,12 +368,12 @@ fn spawn_server(
Ok(())
}
-fn check_pid_file(path: &Path) -> Result<bool> {
+fn check_pid_file(path: &Path) -> Result<Option<u32>> {
let Some(pid) = std::fs::read_to_string(&path)
.ok()
.and_then(|contents| contents.parse::<u32>().ok())
else {
- return Ok(false);
+ return Ok(None);
};
log::debug!("proxy: Checking if process with PID {} exists...", pid);
@@ -339,12 +384,12 @@ fn check_pid_file(path: &Path) -> Result<bool> {
{
Ok(output) if output.status.success() => {
log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid);
- Ok(true)
+ Ok(Some(pid))
}
_ => {
log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file.");
std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?;
- Ok(false)
+ Ok(None)
}
}
}