SSH Remoting: Fix reconnects (#19485)

Conrad Irwin and Nathan created

Before this change, messages could be lost on reconnect; now they will
not be.

Release Notes:

- SSH Remoting: make reconnects smoother

---------

Co-authored-by: Nathan <nathan@zed.dev>

Change summary

Cargo.lock                                                    |   1 
crates/collab/src/tests/remote_editing_collaboration_tests.rs |   3 
crates/project/src/project.rs                                 |   4 
crates/proto/proto/zed.proto                                  |   8 
crates/proto/src/macros.rs                                    |   1 
crates/proto/src/proto.rs                                     |   2 
crates/remote/Cargo.toml                                      |   1 
crates/remote/src/ssh_session.rs                              | 598 +++-
crates/remote_server/src/remote_editing_tests.rs              |  46 
crates/remote_server/src/unix.rs                              |   2 
10 files changed, 467 insertions(+), 199 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9119,6 +9119,7 @@ name = "remote"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-trait",
  "collections",
  "fs",
  "futures 0.3.30",

crates/collab/src/tests/remote_editing_collaboration_tests.rs 🔗

@@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project(
         .await;
 
     // Set up project on remote FS
-    let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx);
+    let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
     let remote_fs = FakeFs::new(server_cx.executor());
     remote_fs
         .insert_tree(
@@ -67,6 +67,7 @@ async fn test_sharing_an_ssh_remote_project(
         )
     });
 
+    let client_ssh = SshRemoteClient::fake_client(port, cx_a).await;
     let (project_a, worktree_id) = client_a
         .build_ssh_project("/code/project1", client_ssh, cx_a)
         .await;

crates/project/src/project.rs 🔗

@@ -1243,6 +1243,10 @@ impl Project {
         self.client.clone()
     }
 
+    pub fn ssh_client(&self) -> Option<Model<SshRemoteClient>> {
+        self.ssh_client.clone()
+    }
+
     pub fn user_store(&self) -> Model<UserStore> {
         self.user_store.clone()
     }

crates/proto/proto/zed.proto 🔗

@@ -12,6 +12,7 @@ message Envelope {
     uint32 id = 1;
     optional uint32 responding_to = 2;
     optional PeerId original_sender_id = 3;
+    optional uint32 ack_id = 266;
 
     oneof payload {
         Hello hello = 4;
@@ -295,7 +296,9 @@ message Envelope {
         OpenServerSettings open_server_settings = 263;
 
         GetPermalinkToLine get_permalink_to_line = 264;
-        GetPermalinkToLineResponse get_permalink_to_line_response = 265;  // current max
+        GetPermalinkToLineResponse get_permalink_to_line_response = 265;
+
+        FlushBufferedMessages flush_buffered_messages = 267;
     }
 
     reserved 87 to 88;
@@ -2522,3 +2525,6 @@ message GetPermalinkToLine {
 message GetPermalinkToLineResponse {
     string permalink = 1;
 }
+
+message FlushBufferedMessages {}
+message FlushBufferedMessagesResponse {}

crates/proto/src/macros.rs 🔗

@@ -32,6 +32,7 @@ macro_rules! messages {
                         responding_to,
                         original_sender_id,
                         payload: Some(envelope::Payload::$name(self)),
+                        ack_id: None,
                     }
                 }
 

crates/proto/src/proto.rs 🔗

@@ -372,6 +372,7 @@ messages!(
     (OpenServerSettings, Foreground),
     (GetPermalinkToLine, Foreground),
     (GetPermalinkToLineResponse, Foreground),
+    (FlushBufferedMessages, Foreground),
 );
 
 request_messages!(
@@ -498,6 +499,7 @@ request_messages!(
     (RemoveWorktree, Ack),
     (OpenServerSettings, OpenBufferResponse),
     (GetPermalinkToLine, GetPermalinkToLineResponse),
+    (FlushBufferedMessages, Ack),
 );
 
 entity_messages!(

crates/remote/Cargo.toml 🔗

@@ -19,6 +19,7 @@ test-support = ["fs/test-support"]
 
 [dependencies]
 anyhow.workspace = true
+async-trait.workspace = true
 collections.workspace = true
 fs.workspace = true
 futures.workspace = true

crates/remote/src/ssh_session.rs 🔗

@@ -6,6 +6,7 @@ use crate::{
     proxy::ProxyLaunchError,
 };
 use anyhow::{anyhow, Context as _, Result};
+use async_trait::async_trait;
 use collections::HashMap;
 use futures::{
     channel::{
@@ -13,7 +14,7 @@ use futures::{
         oneshot,
     },
     future::BoxFuture,
-    select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _,
+    select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
 };
 use gpui::{
     AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
@@ -30,13 +31,14 @@ use smol::{
 };
 use std::{
     any::TypeId,
+    collections::VecDeque,
     ffi::OsStr,
     fmt,
     ops::ControlFlow,
     path::{Path, PathBuf},
     sync::{
         atomic::{AtomicU32, Ordering::SeqCst},
-        Arc,
+        Arc, Weak,
     },
     time::{Duration, Instant},
 };
@@ -275,68 +277,6 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
     }
 }
 
-struct ChannelForwarder {
-    quit_tx: UnboundedSender<()>,
-    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
-}
-
-impl ChannelForwarder {
-    fn new(
-        mut incoming_tx: UnboundedSender<Envelope>,
-        mut outgoing_rx: UnboundedReceiver<Envelope>,
-        cx: &AsyncAppContext,
-    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
-        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
-
-        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
-        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
-
-        let forwarding_task = cx.background_executor().spawn(async move {
-            loop {
-                select_biased! {
-                    _ = quit_rx.next().fuse() => {
-                        break;
-                    },
-                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
-                        if let Some(envelope) = incoming_envelope {
-                            if incoming_tx.send(envelope).await.is_err() {
-                                break;
-                            }
-                        } else {
-                            break;
-                        }
-                    }
-                    outgoing_envelope = outgoing_rx.next().fuse() => {
-                        if let Some(envelope) = outgoing_envelope {
-                            if proxy_outgoing_tx.send(envelope).await.is_err() {
-                                break;
-                            }
-                        } else {
-                            break;
-                        }
-                    }
-                }
-            }
-
-            (incoming_tx, outgoing_rx)
-        });
-
-        (
-            Self {
-                forwarding_task,
-                quit_tx,
-            },
-            proxy_incoming_tx,
-            proxy_outgoing_rx,
-        )
-    }
-
-    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
-        let _ = self.quit_tx.send(()).await;
-        self.forwarding_task.await
-    }
-}
-
 const MAX_MISSED_HEARTBEATS: usize = 5;
 const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
@@ -346,9 +286,8 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3;
 enum State {
     Connecting,
     Connected {
-        ssh_connection: SshRemoteConnection,
+        ssh_connection: Box<dyn SshRemoteProcess>,
         delegate: Arc<dyn SshClientDelegate>,
-        forwarder: ChannelForwarder,
 
         multiplex_task: Task<Result<()>>,
         heartbeat_task: Task<Result<()>>,
@@ -356,18 +295,16 @@ enum State {
     HeartbeatMissed {
         missed_heartbeats: usize,
 
-        ssh_connection: SshRemoteConnection,
+        ssh_connection: Box<dyn SshRemoteProcess>,
         delegate: Arc<dyn SshClientDelegate>,
-        forwarder: ChannelForwarder,
 
         multiplex_task: Task<Result<()>>,
         heartbeat_task: Task<Result<()>>,
     },
     Reconnecting,
     ReconnectFailed {
-        ssh_connection: SshRemoteConnection,
+        ssh_connection: Box<dyn SshRemoteProcess>,
         delegate: Arc<dyn SshClientDelegate>,
-        forwarder: ChannelForwarder,
 
         error: anyhow::Error,
         attempts: usize,
@@ -391,11 +328,11 @@ impl fmt::Display for State {
 }
 
 impl State {
-    fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
+    fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
         match self {
-            Self::Connected { ssh_connection, .. } => Some(ssh_connection),
-            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
-            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
+            Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
+            Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
+            Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
             _ => None,
         }
     }
@@ -429,14 +366,12 @@ impl State {
             Self::HeartbeatMissed {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
                 ..
             } => Self::Connected {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             },
@@ -449,14 +384,12 @@ impl State {
             Self::Connected {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             } => Self::HeartbeatMissed {
                 missed_heartbeats: 1,
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             },
@@ -464,14 +397,12 @@ impl State {
                 missed_heartbeats,
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             } => Self::HeartbeatMissed {
                 missed_heartbeats: missed_heartbeats + 1,
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             },
@@ -529,7 +460,8 @@ impl SshRemoteClient {
             let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
             let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 
-            let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
+            let client =
+                cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
             let this = cx.new_model(|_| Self {
                 client: client.clone(),
                 unique_identifier: unique_identifier.clone(),
@@ -537,26 +469,19 @@ impl SshRemoteClient {
                 state: Arc::new(Mutex::new(Some(State::Connecting))),
             })?;
 
-            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
-                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
-
-            let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
+            let (ssh_connection, io_task) = Self::establish_connection(
                 unique_identifier,
                 false,
                 connection_options,
+                incoming_tx,
+                outgoing_rx,
+                connection_activity_tx,
                 delegate.clone(),
                 &mut cx,
             )
             .await?;
 
-            let multiplex_task = Self::multiplex(
-                this.downgrade(),
-                ssh_proxy_process,
-                proxy_incoming_tx,
-                proxy_outgoing_rx,
-                connection_activity_tx,
-                &mut cx,
-            );
+            let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
 
             if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                 log::error!("failed to establish connection: {}", error);
@@ -570,7 +495,6 @@ impl SshRemoteClient {
                 *this.state.lock() = Some(State::Connected {
                     ssh_connection,
                     delegate,
-                    forwarder: proxy,
                     multiplex_task,
                     heartbeat_task,
                 });
@@ -592,7 +516,6 @@ impl SshRemoteClient {
             heartbeat_task,
             ssh_connection,
             delegate,
-            forwarder,
         } = state
         else {
             return None;
@@ -616,7 +539,6 @@ impl SshRemoteClient {
             drop(heartbeat_task);
             drop(ssh_connection);
             drop(delegate);
-            drop(forwarder);
         })
     }
 
@@ -638,33 +560,30 @@ impl SshRemoteClient {
         }
 
         let state = lock.take().unwrap();
-        let (attempts, mut ssh_connection, delegate, forwarder) = match state {
+        let (attempts, mut ssh_connection, delegate) = match state {
             State::Connected {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
             }
             | State::HeartbeatMissed {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task,
                 ..
             } => {
                 drop(multiplex_task);
                 drop(heartbeat_task);
-                (0, ssh_connection, delegate, forwarder)
+                (0, ssh_connection, delegate)
             }
             State::ReconnectFailed {
                 attempts,
                 ssh_connection,
                 delegate,
-                forwarder,
                 ..
-            } => (attempts, ssh_connection, delegate, forwarder),
+            } => (attempts, ssh_connection, delegate),
             State::Connecting
             | State::Reconnecting
             | State::ReconnectExhausted
@@ -691,41 +610,37 @@ impl SshRemoteClient {
         let client = self.client.clone();
         let reconnect_task = cx.spawn(|this, mut cx| async move {
             macro_rules! failed {
-                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
+                ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
                     return State::ReconnectFailed {
                         error: anyhow!($error),
                         attempts: $attempts,
                         ssh_connection: $ssh_connection,
                         delegate: $delegate,
-                        forwarder: $forwarder,
                     };
                 };
             }
 
-            if let Err(error) = ssh_connection.master_process.kill() {
-                failed!(error, attempts, ssh_connection, delegate, forwarder);
-            };
-
             if let Err(error) = ssh_connection
-                .master_process
-                .status()
+                .kill()
                 .await
                 .context("Failed to kill ssh process")
             {
-                failed!(error, attempts, ssh_connection, delegate, forwarder);
-            }
+                failed!(error, attempts, ssh_connection, delegate);
+            };
 
-            let connection_options = ssh_connection.socket.connection_options.clone();
+            let connection_options = ssh_connection.connection_options();
 
-            let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
-            let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
-                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+            let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
+            let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
             let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 
-            let (ssh_connection, ssh_process) = match Self::establish_connection(
+            let (ssh_connection, io_task) = match Self::establish_connection(
                 identifier,
                 true,
                 connection_options,
+                incoming_tx,
+                outgoing_rx,
+                connection_activity_tx,
                 delegate.clone(),
                 &mut cx,
             )
@@ -733,27 +648,20 @@ impl SshRemoteClient {
             {
                 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
                 Err(error) => {
-                    failed!(error, attempts, ssh_connection, delegate, forwarder);
+                    failed!(error, attempts, ssh_connection, delegate);
                 }
             };
 
-            let multiplex_task = Self::multiplex(
-                this.clone(),
-                ssh_process,
-                proxy_incoming_tx,
-                proxy_outgoing_rx,
-                connection_activity_tx,
-                &mut cx,
-            );
+            let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
+            client.reconnect(incoming_rx, outgoing_tx, &cx);
 
-            if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
-                failed!(error, attempts, ssh_connection, delegate, forwarder);
+            if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
+                failed!(error, attempts, ssh_connection, delegate);
             };
 
             State::Connected {
                 ssh_connection,
                 delegate,
-                forwarder,
                 multiplex_task,
                 heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
             }
@@ -797,7 +705,7 @@ impl SshRemoteClient {
                     cx.emit(SshRemoteEvent::Disconnected);
                     Ok(())
                 } else {
-                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
+                    log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
                     Ok(())
                 }
             })
@@ -910,13 +818,12 @@ impl SshRemoteClient {
     }
 
     fn multiplex(
-        this: WeakModel<Self>,
         mut ssh_proxy_process: Child,
         incoming_tx: UnboundedSender<Envelope>,
         mut outgoing_rx: UnboundedReceiver<Envelope>,
         mut connection_activity_tx: Sender<()>,
         cx: &AsyncAppContext,
-    ) -> Task<Result<()>> {
+    ) -> Task<Result<i32>> {
         let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
         let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
         let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
@@ -988,7 +895,7 @@ impl SshRemoteClient {
             }
         });
 
-        cx.spawn(|mut cx| async move {
+        cx.spawn(|_| async move {
             let result = futures::select! {
                 result = stdin_task.fuse() => {
                     result.context("stdin")
@@ -1002,9 +909,22 @@ impl SshRemoteClient {
             };
 
             match result {
-                Ok(_) => {
-                    let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1);
+                Ok(_) => Ok(ssh_proxy_process.status().await?.code().unwrap_or(1)),
+                Err(error) => Err(error),
+            }
+        })
+    }
 
+    fn monitor(
+        this: WeakModel<Self>,
+        io_task: Task<Result<i32>>,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<()>> {
+        cx.spawn(|mut cx| async move {
+            let result = io_task.await;
+
+            match result {
+                Ok(exit_code) => {
                     if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
                         match error {
                             ProxyLaunchError::ServerNotRunning => {
@@ -1058,21 +978,40 @@ impl SshRemoteClient {
         cx.notify();
     }
 
+    #[allow(clippy::too_many_arguments)]
     async fn establish_connection(
         unique_identifier: String,
         reconnect: bool,
         connection_options: SshConnectionOptions,
+        incoming_tx: UnboundedSender<Envelope>,
+        outgoing_rx: UnboundedReceiver<Envelope>,
+        connection_activity_tx: Sender<()>,
         delegate: Arc<dyn SshClientDelegate>,
         cx: &mut AsyncAppContext,
-    ) -> Result<(SshRemoteConnection, Child)> {
+    ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
+        #[cfg(any(test, feature = "test-support"))]
+        if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
+            let io_task = fake::SshRemoteConnection::multiplex(
+                fake.connection_options(),
+                incoming_tx,
+                outgoing_rx,
+                connection_activity_tx,
+                cx,
+            )
+            .await;
+            return Ok((fake, io_task));
+        }
+
         let ssh_connection =
             SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 
         let platform = ssh_connection.query_platform().await?;
         let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
-        ssh_connection
-            .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
-            .await?;
+        if !reconnect {
+            ssh_connection
+                .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
+                .await?;
+        }
 
         let socket = ssh_connection.socket.clone();
         run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
@@ -1097,7 +1036,15 @@ impl SshRemoteClient {
             .spawn()
             .context("failed to spawn remote server")?;
 
-        Ok((ssh_connection, ssh_proxy_process))
+        let io_task = Self::multiplex(
+            ssh_proxy_process,
+            incoming_tx,
+            outgoing_rx,
+            connection_activity_tx,
+            &cx,
+        );
+
+        Ok((Box::new(ssh_connection), io_task))
     }
 
     pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@@ -1109,7 +1056,7 @@ impl SshRemoteClient {
             .lock()
             .as_ref()
             .and_then(|state| state.ssh_connection())
-            .map(|ssh_connection| ssh_connection.socket.ssh_args())
+            .map(|ssh_connection| ssh_connection.ssh_args())
     }
 
     pub fn proto_client(&self) -> AnyProtoClient {
@@ -1124,7 +1071,6 @@ impl SshRemoteClient {
         self.connection_options.clone()
     }
 
-    #[cfg(not(any(test, feature = "test-support")))]
     pub fn connection_state(&self) -> ConnectionState {
         self.state
             .lock()
@@ -1133,37 +1079,59 @@ impl SshRemoteClient {
             .unwrap_or(ConnectionState::Disconnected)
     }
 
-    #[cfg(any(test, feature = "test-support"))]
-    pub fn connection_state(&self) -> ConnectionState {
-        ConnectionState::Connected
-    }
-
     pub fn is_disconnected(&self) -> bool {
         self.connection_state() == ConnectionState::Disconnected
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn fake(
+    pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
+        let port = self.connection_options().port.unwrap();
+        client_cx.spawn(|cx| async move {
+            let (channel, server_cx) = cx
+                .update_global(|c: &mut fake::ServerConnections, _| c.get(port))
+                .unwrap();
+
+            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+            channel.reconnect(incoming_rx, outgoing_tx, &server_cx);
+        })
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn fake_server(
         client_cx: &mut gpui::TestAppContext,
         server_cx: &mut gpui::TestAppContext,
-    ) -> (Model<Self>, Arc<ChannelClient>) {
-        use gpui::Context;
-
-        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
-        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
-
-        (
-            client_cx.update(|cx| {
-                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
-                cx.new_model(|_| Self {
-                    client,
-                    unique_identifier: "fake".to_string(),
-                    connection_options: SshConnectionOptions::default(),
-                    state: Arc::new(Mutex::new(None)),
-                })
-            }),
-            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
-        )
+    ) -> (u16, Arc<ChannelClient>) {
+        use gpui::BorrowAppContext;
+        let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+        let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+        let server_client =
+            server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
+        let port = client_cx.update(|cx| {
+            cx.update_default_global(|c: &mut fake::ServerConnections, _| {
+                c.push(server_client.clone(), server_cx.to_async())
+            })
+        });
+        (port, server_client)
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model<Self> {
+        client_cx
+            .update(|cx| {
+                Self::new(
+                    "fake".to_string(),
+                    SshConnectionOptions {
+                        host: "<fake>".to_string(),
+                        port: Some(port),
+                        ..Default::default()
+                    },
+                    Arc::new(fake::Delegate),
+                    cx,
+                )
+            })
+            .await
+            .unwrap()
     }
 }
 
@@ -1173,6 +1141,13 @@ impl From<SshRemoteClient> for AnyProtoClient {
     }
 }
 
+#[async_trait]
+trait SshRemoteProcess: Send + Sync {
+    async fn kill(&mut self) -> Result<()>;
+    fn ssh_args(&self) -> Vec<String>;
+    fn connection_options(&self) -> SshConnectionOptions;
+}
+
 struct SshRemoteConnection {
     socket: SshSocket,
     master_process: process::Child,
@@ -1187,6 +1162,25 @@ impl Drop for SshRemoteConnection {
     }
 }
 
+#[async_trait]
+impl SshRemoteProcess for SshRemoteConnection {
+    async fn kill(&mut self) -> Result<()> {
+        self.master_process.kill()?;
+
+        self.master_process.status().await?;
+
+        Ok(())
+    }
+
+    fn ssh_args(&self) -> Vec<String> {
+        self.socket.ssh_args()
+    }
+
+    fn connection_options(&self) -> SshConnectionOptions {
+        self.socket.connection_options.clone()
+    }
+}
+
 impl SshRemoteConnection {
     #[cfg(not(unix))]
     async fn new(
@@ -1469,9 +1463,13 @@ type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, ones
 
 pub struct ChannelClient {
     next_message_id: AtomicU32,
-    outgoing_tx: mpsc::UnboundedSender<Envelope>,
-    response_channels: ResponseChannels,             // Lock
-    message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
+    outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
+    buffer: Mutex<VecDeque<Envelope>>,
+    response_channels: ResponseChannels,
+    message_handlers: Mutex<ProtoMessageHandlerSet>,
+    max_received: AtomicU32,
+    name: &'static str,
+    task: Mutex<Task<Result<()>>>,
 }
 
 impl ChannelClient {
@@ -1479,32 +1477,59 @@ impl ChannelClient {
         incoming_rx: mpsc::UnboundedReceiver<Envelope>,
         outgoing_tx: mpsc::UnboundedSender<Envelope>,
         cx: &AppContext,
+        name: &'static str,
     ) -> Arc<Self> {
-        let this = Arc::new(Self {
-            outgoing_tx,
+        Arc::new_cyclic(|this| Self {
+            outgoing_tx: Mutex::new(outgoing_tx),
             next_message_id: AtomicU32::new(0),
+            max_received: AtomicU32::new(0),
             response_channels: ResponseChannels::default(),
             message_handlers: Default::default(),
-        });
-
-        Self::start_handling_messages(this.clone(), incoming_rx, cx);
-
-        this
+            buffer: Mutex::new(VecDeque::new()),
+            name,
+            task: Mutex::new(Self::start_handling_messages(
+                this.clone(),
+                incoming_rx,
+                &cx.to_async(),
+            )),
+        })
     }
 
     fn start_handling_messages(
-        this: Arc<Self>,
+        this: Weak<Self>, // weak handle: the loop below returns once the client can no longer be upgraded
         mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
-        cx: &AppContext,
-    ) {
+        cx: &AsyncAppContext,
+    ) -> Task<Result<()>> { // the spawned task is handed back to the caller instead of being detached here
         cx.spawn(|cx| {
-            let this = Arc::downgrade(&this);
             async move {
                 let peer_id = PeerId { owner_id: 0, id: 0 };
                 while let Some(incoming) = incoming_rx.next().await {
                     let Some(this) = this.upgrade() else {
                         return anyhow::Ok(());
                     };
+                    if let Some(ack_id) = incoming.ack_id { // peer reports having received up to ack_id: drop those envelopes from the resend buffer
+                        let mut buffer = this.buffer.lock();
+                        while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
+                            buffer.pop_front();
+                        }
+                    }
+                    if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
+                        &incoming.payload
+                    { // flush request: replay every still-buffered envelope, then answer with an Ack
+                        log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name);
+                        { // inner scope so the buffer lock is released before the Ack is sent
+                            let buffer = this.buffer.lock();
+                            for envelope in buffer.iter() {
+                                this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok();
+                            }
+                        }
+                        let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None); // NOTE(review): Some(incoming.id) is presumably the responding_to slot — confirm against into_envelope
+                        envelope.id = this.next_message_id.fetch_add(1, SeqCst); // replace the placeholder id 0 with a fresh message id
+                        this.outgoing_tx.lock().unbounded_send(envelope).ok();
+                        continue; // flush requests are neither recorded in max_received nor dispatched below
+                    }
+
+                    this.max_received.store(incoming.id, SeqCst); // id of the most recently received message; send_buffered sends it back as ack_id
 
                     if let Some(request_id) = incoming.responding_to {
                         let request_id = MessageId(request_id);
@@ -1526,26 +1551,37 @@ impl ChannelClient {
                             this.clone().into(),
                             cx.clone(),
                         ) {
-                            log::debug!("ssh message received. name:{type_name}");
-                            match future.await {
-                                Ok(_) => {
-                                    log::debug!("ssh message handled. name:{type_name}");
+                            log::debug!("{}:ssh message received. name:{type_name}", this.name);
+                            cx.foreground_executor().spawn(async move { // handler runs as its own detached task, so this loop keeps reading while it awaits
+                                match future.await {
+                                    Ok(_) => {
+                                        log::debug!("{}:ssh message handled. name:{type_name}", this.name);
+                                    }
+                                    Err(error) => {
+                                        log::error!(
+                                            "{}:error handling message. type:{type_name}, error:{error}", this.name,
+                                        );
+                                    }
                                 }
-                                Err(error) => {
-                                    log::error!(
-                                        "error handling message. type:{type_name}, error:{error}",
-                                    );
-                                }
-                            }
+                            }).detach()
                         } else {
-                            log::error!("unhandled ssh message name:{type_name}");
+                            log::error!("{}:unhandled ssh message name:{type_name}", this.name);
                         }
                     }
                 }
                 anyhow::Ok(())
             }
         })
-        .detach();
+    }
+
+    pub fn reconnect( // swaps in the channels of a new connection; only outgoing_tx and task are replaced
+        self: &Arc<Self>,
+        incoming_rx: UnboundedReceiver<Envelope>,
+        outgoing_tx: UnboundedSender<Envelope>,
+        cx: &AsyncAppContext,
+    ) {
+        *self.outgoing_tx.lock() = outgoing_tx; // sends made after this point go to the new channel
+        *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx); // overwrites (drops) the previous receive task; NOTE(review): presumably dropping a Task cancels it — confirm
     }
 
     pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@@ -1581,6 +1617,26 @@ impl ChannelClient {
         }
     }
 
+    pub async fn resync(&self, timeout: Duration) -> Result<()> { // asks the peer to flush its buffer, then replays ours, racing that work against `timeout`
+        smol::future::or(
+            async {
+                self.request(proto::FlushBufferedMessages {}).await?; // a failed request returns early and skips the replay below
+                for envelope in self.buffer.lock().iter() { // re-send every envelope still held in the buffer
+                    self.outgoing_tx
+                        .lock()
+                        .unbounded_send(envelope.clone())
+                        .ok();
+                }
+                Ok(())
+            },
+            async {
+                smol::Timer::after(timeout).await;
+                Err(anyhow!("Timeout detected")) // the outcome when the timer wins the race
+            },
+        )
+        .await
+    }
+
     pub async fn ping(&self, timeout: Duration) -> Result<()> {
         smol::future::or(
             async {
@@ -1610,7 +1666,8 @@ impl ChannelClient {
         let mut response_channels_lock = self.response_channels.lock();
         response_channels_lock.insert(MessageId(envelope.id), tx);
         drop(response_channels_lock);
-        let result = self.outgoing_tx.unbounded_send(envelope);
+
+        let result = self.send_buffered(envelope);
         async move {
             if let Err(error) = &result {
                 log::error!("failed to send message: {}", error);
@@ -1627,7 +1684,15 @@ impl ChannelClient {
 
     pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
         envelope.id = self.next_message_id.fetch_add(1, SeqCst);
-        self.outgoing_tx.unbounded_send(envelope)?;
+        self.send_buffered(envelope) // buffers for replay and sends; send_buffered always returns Ok(()), so send failures no longer propagate
+    }
+
+    pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> { // records the envelope for replay, then sends it best-effort; always returns Ok(())
+        envelope.ack_id = Some(self.max_received.load(SeqCst)); // piggyback the id of the last message we received
+        self.buffer.lock().push_back(envelope.clone()); // kept until an incoming ack_id covers this envelope's id
+        // ignore errors on send (happen while we're reconnecting)
+        // assume that the global "disconnected" overlay is sufficient.
+        self.outgoing_tx.lock().unbounded_send(envelope).ok();
         Ok(())
     }
 }
@@ -1657,3 +1722,148 @@ impl ProtoClient for ChannelClient {
         false
     }
 }
+
+#[cfg(any(test, feature = "test-support"))]
+mod fake {
+    use std::{path::PathBuf, sync::Arc};
+
+    use anyhow::Result;
+    use async_trait::async_trait;
+    use futures::{
+        channel::{
+            mpsc::{self, Sender},
+            oneshot,
+        },
+        select_biased, FutureExt, SinkExt, StreamExt,
+    };
+    use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
+    use rpc::proto::Envelope;
+
+    use super::{
+        ChannelClient, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
+    };
+
+    pub(super) struct SshRemoteConnection { // fake SshRemoteProcess implementation; it only remembers the options it was created with
+        connection_options: SshConnectionOptions,
+    }
+
+    impl SshRemoteConnection {
+        pub(super) fn new(
+            connection_options: &SshConnectionOptions,
+        ) -> Option<Box<dyn SshRemoteProcess>> {
+            if connection_options.host == "<fake>" { // the sentinel host "<fake>" selects the fake; any other host yields None
+                return Some(Box::new(Self {
+                    connection_options: connection_options.clone(),
+                }));
+            }
+            return None; // NOTE(review): a plain `None` tail expression would be the idiomatic form
+        }
+        pub(super) async fn multiplex( // wires a client's channels to the fake server selected by connection_options.port and returns the forwarding task
+            connection_options: SshConnectionOptions,
+            mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
+            mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
+            mut connection_activity_tx: Sender<()>,
+            cx: &mut AsyncAppContext,
+        ) -> Task<Result<i32>> {
+            let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
+            let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
+
+            let (channel, server_cx) = cx
+                .update(|cx| {
+                    cx.update_global(|conns: &mut ServerConnections, _| {
+                        conns.get(connection_options.port.unwrap()) // panics if no port is set; the port is the index into ServerConnections
+                    })
+                })
+                .unwrap();
+            channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx); // point the server-side ChannelClient at the fresh channel pair
+
+            // Forward envelopes in both directions until either receiver's stream ends:
+            // server_outgoing_rx -> client_incoming_tx, client_outgoing_rx -> server_incoming_tx.
+
+            cx.background_executor().spawn(async move {
+                loop {
+                    select_biased! { // biased: when both arms are ready, the server->client arm listed first wins
+                        server_to_client = server_outgoing_rx.next().fuse() => {
+                            let Some(server_to_client) = server_to_client else {
+                                return Ok(1) // stream ended; NOTE(review): 1 presumably stands in for a process exit status — confirm
+                            };
+                            connection_activity_tx.try_send(()).ok(); // best-effort activity signal; a send error is ignored
+                            client_incoming_tx.send(server_to_client).await.ok();
+                        }
+                        client_to_server = client_outgoing_rx.next().fuse() => {
+                            let Some(client_to_server) = client_to_server else {
+                                return Ok(1)
+                            };
+                            server_incoming_tx.send(client_to_server).await.ok();
+                        }
+                    }
+                }
+            })
+        }
+    }
+
+    #[async_trait]
+    impl SshRemoteProcess for SshRemoteConnection { // trait impl for the fake: kill is a no-op and there are no ssh arguments
+        async fn kill(&mut self) -> Result<()> {
+            Ok(()) // nothing to kill; always succeeds
+        }
+
+        fn ssh_args(&self) -> Vec<String> {
+            Vec::new() // always an empty argument list
+        }
+
+        fn connection_options(&self) -> SshConnectionOptions {
+            self.connection_options.clone() // hands back a copy of the options passed to `new`
+        }
+    }
+
+    #[derive(Default)]
+    pub(super) struct ServerConnections(Vec<(Arc<ChannelClient>, AsyncAppContext)>); // registry of fake servers; an entry's index doubles as its "port"
+    impl Global for ServerConnections {}
+
+    impl ServerConnections {
+        pub(super) fn push(&mut self, server: Arc<ChannelClient>, cx: AsyncAppContext) -> u16 { // registers a server and returns the port to reach it by
+            self.0.push((server.clone(), cx)); // NOTE(review): the clone is redundant — `server` is owned and not used afterwards
+            self.0.len() as u16 - 1 // index of the entry just pushed; NOTE(review): `as u16` wraps once the Vec exceeds u16::MAX entries
+        }
+
+        pub(super) fn get(&mut self, port: u16) -> (Arc<ChannelClient>, AsyncAppContext) { // clones the entry registered under `port`
+            self.0
+                .get(port as usize)
+                .expect("no fake server for port") // panics when nothing was registered at that index
+                .clone()
+        }
+    }
+
+    pub(super) struct Delegate; // SshClientDelegate stub for fake connections: every callback below is unreachable!()
+
+    impl SshClientDelegate for Delegate { // any call into this delegate panics
+        fn ask_password(
+            &self,
+            _: String,
+            _: &mut AsyncAppContext,
+        ) -> oneshot::Receiver<Result<String>> {
+            unreachable!()
+        }
+        fn remote_server_binary_path(
+            &self,
+            _: SshPlatform,
+            _: &mut AsyncAppContext,
+        ) -> Result<PathBuf> {
+            unreachable!()
+        }
+        fn get_server_binary(
+            &self,
+            _: SshPlatform,
+            _: &mut AsyncAppContext,
+        ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
+            unreachable!()
+        }
+        fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
+            unreachable!()
+        }
+        fn set_error(&self, _: String, _: &mut AsyncAppContext) {
+            unreachable!()
+        }
+    }
+}

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -641,6 +641,47 @@ async fn test_open_server_settings(cx: &mut TestAppContext, server_cx: &mut Test
     })
 }
 
+#[gpui::test(iterations = 20)]
+async fn test_reconnect(cx: &mut TestAppContext, server_cx: &mut TestAppContext) { // edits a buffer, simulates a disconnect, then saves and checks the edit reached the fs
+    let (project, _headless, fs) = init_test(cx, server_cx).await;
+
+    let (worktree, _) = project
+        .update(cx, |project, cx| {
+            project.find_or_create_worktree("/code/project1", true, cx)
+        })
+        .await
+        .unwrap();
+
+    let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id());
+    let buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer((worktree_id, Path::new("src/lib.rs")), cx)
+        })
+        .await
+        .unwrap();
+
+    buffer.update(cx, |buffer, cx| {
+        assert_eq!(buffer.text(), "fn one() -> usize { 1 }");
+        let ix = buffer.text().find('1').unwrap();
+        buffer.edit([(ix..ix + 1, "100")], None, cx); // replace the literal 1 with 100
+    });
+
+    let client = cx.read(|cx| project.read(cx).ssh_client().unwrap());
+    client
+        .update(cx, |client, cx| client.simulate_disconnect(cx)) // break the connection while the edit is still unsaved
+        .detach(); // the returned value is detached rather than awaited
+
+    project
+        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) // the save is issued right after the simulated disconnect
+        .await
+        .unwrap();
+
+    assert_eq!( // the file in the fs returned by init_test must contain the edit
+        fs.load("/code/project1/src/lib.rs".as_ref()).await.unwrap(),
+        "fn one() -> usize { 100 }"
+    );
+}
+
 fn init_logger() {
     if std::env::var("RUST_LOG").is_ok() {
         env_logger::try_init().ok();
@@ -651,9 +692,9 @@ async fn init_test(
     cx: &mut TestAppContext,
     server_cx: &mut TestAppContext,
 ) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) {
-    let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx);
     init_logger();
 
+    let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx);
     let fs = FakeFs::new(server_cx.executor());
     fs.insert_tree(
         "/code",
@@ -694,8 +735,9 @@ async fn init_test(
             cx,
         )
     });
-    let project = build_project(ssh_remote_client, cx);
 
+    let ssh = SshRemoteClient::fake_client(forwarder, cx).await;
+    let project = build_project(ssh, cx);
     project
         .update(cx, {
             let headless = headless.clone();

crates/remote_server/src/unix.rs 🔗

@@ -279,7 +279,7 @@ fn start_server(
     })
     .detach();
 
-    ChannelClient::new(incoming_rx, outgoing_tx, cx)
+    ChannelClient::new(incoming_rx, outgoing_tx, cx, "server")
 }
 
 fn init_paths() -> anyhow::Result<()> {