Cargo.lock 🔗
@@ -11885,6 +11885,7 @@ dependencies = [
"pretty_assertions",
"project",
"recent_projects",
+ "remote",
"rpc",
"serde",
"settings",
Bennet Bo Fenner and Thorsten created
Co-Authored-by: Thorsten <thorsten@zed.dev>
Release Notes:
- N/A
---------
Co-authored-by: Thorsten <thorsten@zed.dev>
Cargo.lock | 1
crates/project/src/project.rs | 6
crates/remote/src/remote.rs | 4
crates/remote/src/ssh_session.rs | 514 +++++++++++++++++++++++++-------
crates/title_bar/Cargo.toml | 1
crates/title_bar/src/title_bar.rs | 10
6 files changed, 412 insertions(+), 124 deletions(-)
@@ -11885,6 +11885,7 @@ dependencies = [
"pretty_assertions",
"project",
"recent_projects",
+ "remote",
"rpc",
"serde",
"settings",
@@ -1263,8 +1263,10 @@ impl Project {
.clone()
}
- pub fn ssh_is_connected(&self, cx: &AppContext) -> Option<bool> {
- Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway())
+ pub fn ssh_connection_state(&self, cx: &AppContext) -> Option<remote::ConnectionState> {
+ self.ssh_client
+ .as_ref()
+ .map(|ssh| ssh.read(cx).connection_state())
}
pub fn replica_id(&self) -> ReplicaId {
@@ -2,4 +2,6 @@ pub mod json_log;
pub mod protocol;
pub mod ssh_session;
-pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient};
+pub use ssh_session::{
+ ConnectionState, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient,
+};
@@ -31,7 +31,8 @@ use smol::{
use std::{
any::TypeId,
ffi::OsStr,
- mem,
+ fmt,
+ ops::ControlFlow,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU32, Ordering::SeqCst},
@@ -40,7 +41,7 @@ use std::{
time::{Duration, Instant},
};
use tempfile::TempDir;
-use util::maybe;
+use util::ResultExt;
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
@@ -234,19 +235,157 @@ impl ChannelForwarder {
}
}
-struct SshRemoteClientState {
- ssh_connection: SshRemoteConnection,
- delegate: Arc<dyn SshClientDelegate>,
- forwarder: ChannelForwarder,
- multiplex_task: Task<Result<()>>,
- heartbeat_task: Task<Result<()>>,
+const MAX_MISSED_HEARTBEATS: usize = 5;
+const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
+const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
+
+const MAX_RECONNECT_ATTEMPTS: usize = 3;
+
+enum State {
+ Connecting,
+ Connected {
+ ssh_connection: SshRemoteConnection,
+ delegate: Arc<dyn SshClientDelegate>,
+ forwarder: ChannelForwarder,
+
+ multiplex_task: Task<Result<()>>,
+ heartbeat_task: Task<Result<()>>,
+ },
+ HeartbeatMissed {
+ missed_heartbeats: usize,
+
+ ssh_connection: SshRemoteConnection,
+ delegate: Arc<dyn SshClientDelegate>,
+ forwarder: ChannelForwarder,
+
+ multiplex_task: Task<Result<()>>,
+ heartbeat_task: Task<Result<()>>,
+ },
+ Reconnecting,
+ ReconnectFailed {
+ ssh_connection: SshRemoteConnection,
+ delegate: Arc<dyn SshClientDelegate>,
+ forwarder: ChannelForwarder,
+
+ error: anyhow::Error,
+ attempts: usize,
+ },
+ ReconnectExhausted,
+}
+
+impl fmt::Display for State {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Connecting => write!(f, "connecting"),
+ Self::Connected { .. } => write!(f, "connected"),
+ Self::Reconnecting => write!(f, "reconnecting"),
+ Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
+ Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
+ Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
+ }
+ }
+}
+
+impl State {
+ fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
+ match self {
+ Self::Connected { ssh_connection, .. } => Some(ssh_connection),
+ Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
+ Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
+ _ => None,
+ }
+ }
+
+ fn can_reconnect(&self) -> bool {
+ matches!(
+ self,
+ Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
+ )
+ }
+
+ fn heartbeat_recovered(self) -> Self {
+ match self {
+ Self::HeartbeatMissed {
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ ..
+ } => Self::Connected {
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ },
+ _ => self,
+ }
+ }
+
+ fn heartbeat_missed(self) -> Self {
+ match self {
+ Self::Connected {
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ } => Self::HeartbeatMissed {
+ missed_heartbeats: 1,
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ },
+ Self::HeartbeatMissed {
+ missed_heartbeats,
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ } => Self::HeartbeatMissed {
+ missed_heartbeats: missed_heartbeats + 1,
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ },
+ _ => self,
+ }
+ }
+}
+
+/// The state of the ssh connection.
+#[derive(Clone, Copy, Debug)]
+pub enum ConnectionState {
+ Connecting,
+ Connected,
+ HeartbeatMissed,
+ Reconnecting,
+ Disconnected,
+}
+
+impl From<&State> for ConnectionState {
+ fn from(value: &State) -> Self {
+ match value {
+ State::Connecting => Self::Connecting,
+ State::Connected { .. } => Self::Connected,
+ State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
+ State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
+ State::ReconnectExhausted => Self::Disconnected,
+ }
+ }
}
pub struct SshRemoteClient {
client: Arc<ChannelClient>,
unique_identifier: String,
connection_options: SshConnectionOptions,
- inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
+ state: Arc<Mutex<Option<State>>>,
}
impl Drop for SshRemoteClient {
@@ -266,6 +405,7 @@ impl SshRemoteClient {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
+ let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = cx.new_model(|cx| {
cx.on_app_quit(|this: &mut Self, _| {
this.shutdown_processes();
@@ -273,47 +413,49 @@ impl SshRemoteClient {
})
.detach();
- let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
Self {
- client,
+ client: client.clone(),
unique_identifier: unique_identifier.clone(),
- connection_options: SshConnectionOptions::default(),
- inner_state: Arc::new(Mutex::new(None)),
+ connection_options: connection_options.clone(),
+ state: Arc::new(Mutex::new(Some(State::Connecting))),
}
})?;
- let inner_state = {
- let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
- ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
-
- let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
- unique_identifier,
- connection_options,
- delegate.clone(),
- &mut cx,
- )
- .await?;
-
- let multiplex_task = Self::multiplex(
- this.downgrade(),
- ssh_proxy_process,
- proxy_incoming_tx,
- proxy_outgoing_rx,
- &mut cx,
- );
-
- SshRemoteClientState {
+ let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+ ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+ let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
+ unique_identifier,
+ connection_options,
+ delegate.clone(),
+ &mut cx,
+ )
+ .await?;
+
+ let multiplex_task = Self::multiplex(
+ this.downgrade(),
+ ssh_proxy_process,
+ proxy_incoming_tx,
+ proxy_outgoing_rx,
+ &mut cx,
+ );
+
+ if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
+ log::error!("failed to establish connection: {}", error);
+ delegate.set_error(error.to_string(), &mut cx);
+ return Err(error);
+ }
+
+ let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
+
+ this.update(&mut cx, |this, _| {
+ *this.state.lock() = Some(State::Connected {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task,
- heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
- }
- };
-
- this.update(&mut cx, |this, cx| {
- this.inner_state.lock().replace(inner_state);
- cx.notify();
+ heartbeat_task,
+ });
})?;
Ok(this)
@@ -321,78 +463,192 @@ impl SshRemoteClient {
}
fn shutdown_processes(&self) {
- let Some(mut state) = self.inner_state.lock().take() else {
+ let Some(state) = self.state.lock().take() else {
return;
};
log::info!("shutting down ssh processes");
+
+ let State::Connected {
+ multiplex_task,
+ heartbeat_task,
+ ..
+ } = state
+ else {
+ return;
+ };
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
- let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
- drop(task);
+ drop(multiplex_task);
// Now drop the rest of state, which kills master process.
- drop(state);
+ drop(heartbeat_task);
}
- fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
- log::info!("Trying to reconnect to ssh server...");
- let Some(state) = self.inner_state.lock().take() else {
- return Err(anyhow!("reconnect is already in progress"));
- };
+ fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
+ let mut lock = self.state.lock();
- let workspace_identifier = self.unique_identifier.clone();
+ let can_reconnect = lock
+ .as_ref()
+ .map(|state| state.can_reconnect())
+ .unwrap_or(false);
+ if !can_reconnect {
+ let error = if let Some(state) = lock.as_ref() {
+ format!("invalid state, cannot reconnect while in state {state}")
+ } else {
+ "no state set".to_string()
+ };
+ return Err(anyhow!(error));
+ }
- let SshRemoteClientState {
- mut ssh_connection,
- delegate,
- forwarder: proxy,
- multiplex_task,
- heartbeat_task,
- } = state;
- drop(multiplex_task);
- drop(heartbeat_task);
+ let state = lock.take().unwrap();
+ let (attempts, mut ssh_connection, delegate, forwarder) = match state {
+ State::Connected {
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ }
+ | State::HeartbeatMissed {
+ ssh_connection,
+ delegate,
+ forwarder,
+ multiplex_task,
+ heartbeat_task,
+ ..
+ } => {
+ drop(multiplex_task);
+ drop(heartbeat_task);
+ (0, ssh_connection, delegate, forwarder)
+ }
+ State::ReconnectFailed {
+ attempts,
+ ssh_connection,
+ delegate,
+ forwarder,
+ ..
+ } => (attempts, ssh_connection, delegate, forwarder),
+ State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
+ };
- cx.spawn(|this, mut cx| async move {
- let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
+ let attempts = attempts + 1;
+ if attempts > MAX_RECONNECT_ATTEMPTS {
+ log::error!(
+ "Failed to reconnect to after {} attempts, giving up",
+ MAX_RECONNECT_ATTEMPTS
+ );
+ *lock = Some(State::ReconnectExhausted);
+ return Ok(());
+ }
+ *lock = Some(State::Reconnecting);
+ drop(lock);
+
+ log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
+
+ let identifier = self.unique_identifier.clone();
+ let client = self.client.clone();
+ let reconnect_task = cx.spawn(|this, mut cx| async move {
+ macro_rules! failed {
+ ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
+ return State::ReconnectFailed {
+ error: anyhow!($error),
+ attempts: $attempts,
+ ssh_connection: $ssh_connection,
+ delegate: $delegate,
+ forwarder: $forwarder,
+ };
+ };
+ }
- ssh_connection.master_process.kill()?;
- ssh_connection
+ if let Err(error) = ssh_connection.master_process.kill() {
+ failed!(error, attempts, ssh_connection, delegate, forwarder);
+ };
+
+ if let Err(error) = ssh_connection
.master_process
.status()
.await
- .context("Failed to kill ssh process")?;
+ .context("Failed to kill ssh process")
+ {
+ failed!(error, attempts, ssh_connection, delegate, forwarder);
+ }
let connection_options = ssh_connection.socket.connection_options.clone();
- let (ssh_connection, ssh_process) = Self::establish_connection(
- workspace_identifier,
+ let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
+ let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
+ ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+ let (ssh_connection, ssh_process) = match Self::establish_connection(
+ identifier,
connection_options,
delegate.clone(),
&mut cx,
)
- .await?;
+ .await
+ {
+ Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
+ Err(error) => {
+ failed!(error, attempts, ssh_connection, delegate, forwarder);
+ }
+ };
- let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
- ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+ let multiplex_task = Self::multiplex(
+ this.clone(),
+ ssh_process,
+ proxy_incoming_tx,
+ proxy_outgoing_rx,
+ &mut cx,
+ );
- let inner_state = SshRemoteClientState {
+ if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
+ failed!(error, attempts, ssh_connection, delegate, forwarder);
+ };
+
+ State::Connected {
ssh_connection,
delegate,
- forwarder: proxy,
- multiplex_task: Self::multiplex(
- this.clone(),
- ssh_process,
- proxy_incoming_tx,
- proxy_outgoing_rx,
- &mut cx,
- ),
+ forwarder,
+ multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
- };
+ }
+ });
- this.update(&mut cx, |this, _| {
- this.inner_state.lock().replace(inner_state);
+ cx.spawn(|this, mut cx| async move {
+ let new_state = reconnect_task.await;
+ this.update(&mut cx, |this, cx| {
+ match &new_state {
+ State::Connecting
+ | State::Reconnecting { .. }
+ | State::HeartbeatMissed { .. } => {}
+ State::Connected { .. } => {
+ log::info!("Successfully reconnected");
+ }
+ State::ReconnectFailed {
+ error, attempts, ..
+ } => {
+ log::error!(
+ "Reconnect attempt {} failed: {:?}. Starting new attempt...",
+ attempts,
+ error
+ );
+ }
+ State::ReconnectExhausted => {
+ log::error!("Reconnect attempt failed and all attempts exhausted");
+ }
+ }
+
+ let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
+ *this.state.lock() = Some(new_state);
+ cx.notify();
+ if reconnect_failed {
+ this.reconnect(cx)
+ } else {
+ Ok(())
+ }
})
})
- .detach();
+ .detach_and_log_err(cx);
+
Ok(())
}
@@ -403,10 +659,6 @@ impl SshRemoteClient {
cx.spawn(|mut cx| {
let this = this.clone();
async move {
- const MAX_MISSED_HEARTBEATS: usize = 5;
- const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
- const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
-
let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
@@ -415,19 +667,7 @@ impl SshRemoteClient {
log::info!("Sending heartbeat to server...");
- let result = smol::future::or(
- async {
- client.request(proto::Ping {}).await?;
- Ok(())
- },
- async {
- smol::Timer::after(HEARTBEAT_TIMEOUT).await;
-
- Err(anyhow!("Timeout detected"))
- },
- )
- .await;
-
+ let result = client.ping(HEARTBEAT_TIMEOUT).await;
if result.is_err() {
missed_heartbeats += 1;
log::warn!(
@@ -440,17 +680,10 @@ impl SshRemoteClient {
missed_heartbeats = 0;
}
- if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
- log::error!(
- "Missed last {} hearbeats. Reconnecting...",
- missed_heartbeats
- );
-
- this.update(&mut cx, |this, cx| {
- this.reconnect(cx)
- .context("failed to reconnect after missing heartbeats")
- })
- .context("failed to update weak reference, SshRemoteClient lost?")??;
+ let result = this.update(&mut cx, |this, mut cx| {
+ this.handle_heartbeat_result(missed_heartbeats, &mut cx)
+ })?;
+ if result.is_break() {
return Ok(());
}
}
@@ -458,6 +691,34 @@ impl SshRemoteClient {
})
}
+ fn handle_heartbeat_result(
+ &mut self,
+ missed_heartbeats: usize,
+ cx: &mut ModelContext<Self>,
+ ) -> ControlFlow<()> {
+ let state = self.state.lock().take().unwrap();
+ self.state.lock().replace(if missed_heartbeats > 0 {
+ state.heartbeat_missed()
+ } else {
+ state.heartbeat_recovered()
+ });
+ cx.notify();
+
+ if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
+ log::error!(
+ "Missed last {} heartbeats. Reconnecting...",
+ missed_heartbeats
+ );
+
+ self.reconnect(cx)
+ .context("failed to start reconnect process after missing heartbeats")
+ .log_err();
+ ControlFlow::Break(())
+ } else {
+ ControlFlow::Continue(())
+ }
+ }
+
fn multiplex(
this: WeakModel<Self>,
mut ssh_proxy_process: Child,
@@ -611,10 +872,11 @@ impl SshRemoteClient {
}
pub fn ssh_args(&self) -> Option<Vec<String>> {
- let state = self.inner_state.lock();
- state
+ self.state
+ .lock()
.as_ref()
- .map(|state| state.ssh_connection.socket.ssh_args())
+ .and_then(|state| state.ssh_connection())
+ .map(|ssh_connection| ssh_connection.socket.ssh_args())
}
pub fn to_proto_client(&self) -> AnyProtoClient {
@@ -625,8 +887,12 @@ impl SshRemoteClient {
self.connection_options.connection_string()
}
- pub fn is_reconnect_underway(&self) -> bool {
- maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
+ pub fn connection_state(&self) -> ConnectionState {
+ self.state
+ .lock()
+ .as_ref()
+ .map(ConnectionState::from)
+ .unwrap_or(ConnectionState::Disconnected)
}
#[cfg(any(test, feature = "test-support"))]
@@ -646,7 +912,7 @@ impl SshRemoteClient {
client,
unique_identifier: "fake".to_string(),
connection_options: SshConnectionOptions::default(),
- inner_state: Arc::new(Mutex::new(None)),
+ state: Arc::new(Mutex::new(None)),
})
}),
server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
@@ -1046,6 +1312,20 @@ impl ChannelClient {
}
}
+ pub async fn ping(&self, timeout: Duration) -> Result<()> {
+ smol::future::or(
+ async {
+ self.request(proto::Ping {}).await?;
+ Ok(())
+ },
+ async {
+ smol::Timer::after(timeout).await;
+ Err(anyhow!("Timeout detected"))
+ },
+ )
+ .await
+ }
+
pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
log::debug!("ssh send name:{}", T::NAME);
self.send_dynamic(payload.into_envelope(0, None, None))
@@ -41,6 +41,7 @@ gpui.workspace = true
notifications.workspace = true
project.workspace = true
recent_projects.workspace = true
+remote.workspace = true
rpc.workspace = true
serde.workspace = true
smallvec.workspace = true
@@ -265,10 +265,12 @@ impl TitleBar {
fn render_ssh_project_host(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let host = self.project.read(cx).ssh_connection_string(cx)?;
let meta = SharedString::from(format!("Connected to: {host}"));
- let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? {
- Color::Success
- } else {
- Color::Warning
+ let indicator_color = match self.project.read(cx).ssh_connection_state(cx)? {
+ remote::ConnectionState::Connecting => Color::Info,
+ remote::ConnectionState::Connected => Color::Success,
+ remote::ConnectionState::HeartbeatMissed => Color::Warning,
+ remote::ConnectionState::Reconnecting => Color::Warning,
+ remote::ConnectionState::Disconnected => Color::Error,
};
let indicator = div()
.absolute()