SSH connection pooling (#19692)

Conrad Irwin and Max created

Co-Authored-By: Max <max@zed.dev>

Closes #ISSUE

Release Notes:

- SSH Remoting: Reuse connections across hosts

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/src/tests/remote_editing_collaboration_tests.rs |   4 
crates/recent_projects/src/remote_servers.rs                  |  11 
crates/remote/src/ssh_session.rs                              | 800 ++--
crates/remote_server/src/remote_editing_tests.rs              |   4 
4 files changed, 483 insertions(+), 336 deletions(-)

Detailed changes

crates/collab/src/tests/remote_editing_collaboration_tests.rs 🔗

@@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project(
         .await;
 
     // Set up project on remote FS
-    let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
+    let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
     let remote_fs = FakeFs::new(server_cx.executor());
     remote_fs
         .insert_tree(
@@ -67,7 +67,7 @@ async fn test_sharing_an_ssh_remote_project(
         )
     });
 
-    let client_ssh = SshRemoteClient::fake_client(port, cx_a).await;
+    let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await;
     let (project_a, worktree_id) = client_a
         .build_ssh_project("/code/project1", client_ssh, cx_a)
         .await;

crates/recent_projects/src/remote_servers.rs 🔗

@@ -17,6 +17,7 @@ use gpui::{
 use picker::Picker;
 use project::Project;
 use remote::SshConnectionOptions;
+use remote::SshRemoteClient;
 use settings::update_settings_file;
 use settings::Settings;
 use ui::{
@@ -46,6 +47,7 @@ pub struct RemoteServerProjects {
     scroll_handle: ScrollHandle,
     workspace: WeakView<Workspace>,
     selectable_items: SelectableItemList,
+    retained_connections: Vec<Model<SshRemoteClient>>,
 }
 
 struct CreateRemoteServer {
@@ -355,6 +357,7 @@ impl RemoteServerProjects {
             scroll_handle: ScrollHandle::new(),
             workspace,
             selectable_items: Default::default(),
+            retained_connections: Vec::new(),
         }
     }
 
@@ -424,7 +427,7 @@ impl RemoteServerProjects {
         let address_editor = editor.clone();
         let creating = cx.spawn(move |this, mut cx| async move {
             match connection.await {
-                Some(_) => this
+                Some(Some(client)) => this
                     .update(&mut cx, |this, cx| {
                         let _ = this.workspace.update(cx, |workspace, _| {
                             workspace
@@ -432,14 +435,14 @@ impl RemoteServerProjects {
                                 .telemetry()
                                 .report_app_event("create ssh server".to_string())
                         });
-
+                        this.retained_connections.push(client);
                         this.add_ssh_server(connection_options, cx);
                         this.mode = Mode::default_mode();
                         this.selectable_items.reset_selection();
                         cx.notify()
                     })
                     .log_err(),
-                None => this
+                _ => this
                     .update(&mut cx, |this, cx| {
                         address_editor.update(cx, |this, _| {
                             this.set_read_only(false);
@@ -1056,7 +1059,7 @@ impl RemoteServerProjects {
                             );
 
                             cx.spawn(|mut cx| async move {
-                                if confirmation.await.ok() == Some(1) {
+                                if confirmation.await.ok() == Some(0) {
                                     remote_servers
                                         .update(&mut cx, |this, cx| {
                                             this.delete_ssh_server(index, cx);

crates/remote/src/ssh_session.rs 🔗

@@ -13,17 +13,18 @@ use futures::{
         mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
         oneshot,
     },
-    future::BoxFuture,
+    future::{BoxFuture, Shared},
     select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
 };
 use gpui::{
-    AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
-    WeakModel,
+    AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
+    ModelContext, SemanticVersion, Task, WeakModel,
 };
 use parking_lot::Mutex;
 use rpc::{
     proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
-    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
+    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
+    RpcError,
 };
 use smol::{
     fs,
@@ -56,7 +57,7 @@ pub struct SshSocket {
     socket_path: PathBuf,
 }
 
-#[derive(Debug, Default, Clone, PartialEq, Eq)]
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
 pub struct SshConnectionOptions {
     pub host: String,
     pub username: Option<String>,
@@ -290,7 +291,7 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3;
 enum State {
     Connecting,
     Connected {
-        ssh_connection: Box<dyn SshRemoteProcess>,
+        ssh_connection: Arc<dyn RemoteConnection>,
         delegate: Arc<dyn SshClientDelegate>,
 
         multiplex_task: Task<Result<()>>,
@@ -299,7 +300,7 @@ enum State {
     HeartbeatMissed {
         missed_heartbeats: usize,
 
-        ssh_connection: Box<dyn SshRemoteProcess>,
+        ssh_connection: Arc<dyn RemoteConnection>,
         delegate: Arc<dyn SshClientDelegate>,
 
         multiplex_task: Task<Result<()>>,
@@ -307,7 +308,7 @@ enum State {
     },
     Reconnecting,
     ReconnectFailed {
-        ssh_connection: Box<dyn SshRemoteProcess>,
+        ssh_connection: Arc<dyn RemoteConnection>,
         delegate: Arc<dyn SshClientDelegate>,
 
         error: anyhow::Error,
@@ -332,7 +333,7 @@ impl fmt::Display for State {
 }
 
 impl State {
-    fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
+    fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
         match self {
             Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
             Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
@@ -462,7 +463,7 @@ impl SshRemoteClient {
         connection_options: SshConnectionOptions,
         cancellation: oneshot::Receiver<()>,
         delegate: Arc<dyn SshClientDelegate>,
-        cx: &AppContext,
+        cx: &mut AppContext,
     ) -> Task<Result<Option<Model<Self>>>> {
         cx.spawn(|mut cx| async move {
             let success = Box::pin(async move {
@@ -479,17 +480,28 @@ impl SshRemoteClient {
                     state: Arc::new(Mutex::new(Some(State::Connecting))),
                 })?;
 
-                let (ssh_connection, io_task) = Self::establish_connection(
+                let ssh_connection = cx
+                    .update(|cx| {
+                        cx.update_default_global(|pool: &mut ConnectionPool, cx| {
+                            pool.connect(connection_options, &delegate, cx)
+                        })
+                    })?
+                    .await
+                    .map_err(|e| e.cloned())?;
+                let remote_binary_path = ssh_connection
+                    .get_remote_binary_path(&delegate, false, &mut cx)
+                    .await?;
+
+                let io_task = ssh_connection.start_proxy(
+                    remote_binary_path,
                     unique_identifier,
                     false,
-                    connection_options,
                     incoming_tx,
                     outgoing_rx,
                     connection_activity_tx,
                     delegate.clone(),
                     &mut cx,
-                )
-                .await?;
+                );
 
                 let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
 
@@ -578,7 +590,7 @@ impl SshRemoteClient {
         }
 
         let state = lock.take().unwrap();
-        let (attempts, mut ssh_connection, delegate) = match state {
+        let (attempts, ssh_connection, delegate) = match state {
             State::Connected {
                 ssh_connection,
                 delegate,
@@ -624,7 +636,7 @@ impl SshRemoteClient {
 
         log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
 
-        let identifier = self.unique_identifier.clone();
+        let unique_identifier = self.unique_identifier.clone();
         let client = self.client.clone();
         let reconnect_task = cx.spawn(|this, mut cx| async move {
             macro_rules! failed {
@@ -652,19 +664,33 @@ impl SshRemoteClient {
             let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
             let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
 
-            let (ssh_connection, io_task) = match Self::establish_connection(
-                identifier,
-                true,
-                connection_options,
-                incoming_tx,
-                outgoing_rx,
-                connection_activity_tx,
-                delegate.clone(),
-                &mut cx,
-            )
+            let (ssh_connection, io_task) = match async {
+                let ssh_connection = cx
+                    .update_global(|pool: &mut ConnectionPool, cx| {
+                        pool.connect(connection_options, &delegate, cx)
+                    })?
+                    .await
+                    .map_err(|error| error.cloned())?;
+
+                let remote_binary_path = ssh_connection
+                    .get_remote_binary_path(&delegate, true, &mut cx)
+                    .await?;
+
+                let io_task = ssh_connection.start_proxy(
+                    remote_binary_path,
+                    unique_identifier,
+                    true,
+                    incoming_tx,
+                    outgoing_rx,
+                    connection_activity_tx,
+                    delegate.clone(),
+                    &mut cx,
+                );
+                anyhow::Ok((ssh_connection, io_task))
+            }
             .await
             {
-                Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
+                Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
                 Err(error) => {
                     failed!(error, attempts, ssh_connection, delegate);
                 }
@@ -834,108 +860,6 @@ impl SshRemoteClient {
         }
     }
 
-    fn multiplex(
-        mut ssh_proxy_process: Child,
-        incoming_tx: UnboundedSender<Envelope>,
-        mut outgoing_rx: UnboundedReceiver<Envelope>,
-        mut connection_activity_tx: Sender<()>,
-        cx: &AsyncAppContext,
-    ) -> Task<Result<i32>> {
-        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
-        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
-        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
-
-        let mut stdin_buffer = Vec::new();
-        let mut stdout_buffer = Vec::new();
-        let mut stderr_buffer = Vec::new();
-        let mut stderr_offset = 0;
-
-        let stdin_task = cx.background_executor().spawn(async move {
-            while let Some(outgoing) = outgoing_rx.next().await {
-                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
-            }
-            anyhow::Ok(())
-        });
-
-        let stdout_task = cx.background_executor().spawn({
-            let mut connection_activity_tx = connection_activity_tx.clone();
-            async move {
-                loop {
-                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
-                    let len = child_stdout.read(&mut stdout_buffer).await?;
-
-                    if len == 0 {
-                        return anyhow::Ok(());
-                    }
-
-                    if len < MESSAGE_LEN_SIZE {
-                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
-                    }
-
-                    let message_len = message_len_from_buffer(&stdout_buffer);
-                    let envelope =
-                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
-                            .await?;
-                    connection_activity_tx.try_send(()).ok();
-                    incoming_tx.unbounded_send(envelope).ok();
-                }
-            }
-        });
-
-        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
-            loop {
-                stderr_buffer.resize(stderr_offset + 1024, 0);
-
-                let len = child_stderr
-                    .read(&mut stderr_buffer[stderr_offset..])
-                    .await?;
-                if len == 0 {
-                    return anyhow::Ok(());
-                }
-
-                stderr_offset += len;
-                let mut start_ix = 0;
-                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
-                    .iter()
-                    .position(|b| b == &b'\n')
-                {
-                    let line_ix = start_ix + ix;
-                    let content = &stderr_buffer[start_ix..line_ix];
-                    start_ix = line_ix + 1;
-                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
-                        record.log(log::logger())
-                    } else {
-                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
-                    }
-                }
-                stderr_buffer.drain(0..start_ix);
-                stderr_offset -= start_ix;
-
-                connection_activity_tx.try_send(()).ok();
-            }
-        });
-
-        cx.spawn(|_| async move {
-            let result = futures::select! {
-                result = stdin_task.fuse() => {
-                    result.context("stdin")
-                }
-                result = stdout_task.fuse() => {
-                    result.context("stdout")
-                }
-                result = stderr_task.fuse() => {
-                    result.context("stderr")
-                }
-            };
-
-            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
-            match result {
-                Ok(_) => Ok(status),
-                Err(error) => Err(error),
-            }
-        })
-    }
-
     fn monitor(
         this: WeakModel<Self>,
         io_task: Task<Result<i32>>,
@@ -1005,75 +929,6 @@ impl SshRemoteClient {
         cx.notify();
     }
 
-    #[allow(clippy::too_many_arguments)]
-    async fn establish_connection(
-        unique_identifier: String,
-        reconnect: bool,
-        connection_options: SshConnectionOptions,
-        incoming_tx: UnboundedSender<Envelope>,
-        outgoing_rx: UnboundedReceiver<Envelope>,
-        connection_activity_tx: Sender<()>,
-        delegate: Arc<dyn SshClientDelegate>,
-        cx: &mut AsyncAppContext,
-    ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
-        #[cfg(any(test, feature = "test-support"))]
-        if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
-            let io_task = fake::SshRemoteConnection::multiplex(
-                fake.connection_options(),
-                incoming_tx,
-                outgoing_rx,
-                connection_activity_tx,
-                cx,
-            )
-            .await;
-            return Ok((fake, io_task));
-        }
-
-        let ssh_connection =
-            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
-
-        let platform = ssh_connection.query_platform().await?;
-        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
-        if !reconnect {
-            ssh_connection
-                .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
-                .await?;
-        }
-
-        let socket = ssh_connection.socket.clone();
-        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
-
-        delegate.set_status(Some("Starting proxy"), cx);
-
-        let mut start_proxy_command = format!(
-            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
-            std::env::var("RUST_LOG").unwrap_or_default(),
-            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
-            remote_binary_path,
-            unique_identifier,
-        );
-        if reconnect {
-            start_proxy_command.push_str(" --reconnect");
-        }
-
-        let ssh_proxy_process = socket
-            .ssh_command(start_proxy_command)
-            // IMPORTANT: we kill this process when we drop the task that uses it.
-            .kill_on_drop(true)
-            .spawn()
-            .context("failed to spawn remote server")?;
-
-        let io_task = Self::multiplex(
-            ssh_proxy_process,
-            incoming_tx,
-            outgoing_rx,
-            connection_activity_tx,
-            &cx,
-        );
-
-        Ok((Box::new(ssh_connection), io_task))
-    }
-
     pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
         self.client.subscribe_to_entity(remote_id, entity);
     }
@@ -1112,15 +967,21 @@ impl SshRemoteClient {
 
     #[cfg(any(test, feature = "test-support"))]
     pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
-        let port = self.connection_options().port.unwrap();
+        let opts = self.connection_options();
         client_cx.spawn(|cx| async move {
-            let (channel, server_cx) = cx
-                .update_global(|c: &mut fake::ServerConnections, _| c.get(port))
+            let connection = cx
+                .update_global(|c: &mut ConnectionPool, _| {
+                    if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
+                        c.clone()
+                    } else {
+                        panic!("missing test connection")
+                    }
+                })
+                .unwrap()
+                .await
                 .unwrap();
 
-            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
-            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
-            channel.reconnect(incoming_rx, outgoing_tx, &server_cx);
+            connection.simulate_disconnect(&cx);
         })
     }
 
@@ -1128,78 +989,190 @@ impl SshRemoteClient {
     pub fn fake_server(
         client_cx: &mut gpui::TestAppContext,
         server_cx: &mut gpui::TestAppContext,
-    ) -> (u16, Arc<ChannelClient>) {
-        use gpui::BorrowAppContext;
+    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
+        let port = client_cx
+            .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
+        let opts = SshConnectionOptions {
+            host: "<fake>".to_string(),
+            port: Some(port),
+            ..Default::default()
+        };
         let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
         let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
         let server_client =
             server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
-        let port = client_cx.update(|cx| {
-            cx.update_default_global(|c: &mut fake::ServerConnections, _| {
-                c.push(server_client.clone(), server_cx.to_async())
+        let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
+            connection_options: opts.clone(),
+            server_cx: fake::SendableCx::new(server_cx.to_async()),
+            server_channel: server_client.clone(),
+        });
+
+        client_cx.update(|cx| {
+            cx.update_default_global(|c: &mut ConnectionPool, cx| {
+                c.connections.insert(
+                    opts.clone(),
+                    ConnectionPoolEntry::Connecting(
+                        cx.foreground_executor()
+                            .spawn({
+                                let connection = connection.clone();
+                                async move { Ok(connection.clone()) }
+                            })
+                            .shared(),
+                    ),
+                );
             })
         });
-        (port, server_client)
+
+        (opts, server_client)
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model<Self> {
+    pub async fn fake_client(
+        opts: SshConnectionOptions,
+        client_cx: &mut gpui::TestAppContext,
+    ) -> Model<Self> {
         let (_tx, rx) = oneshot::channel();
         client_cx
-            .update(|cx| {
-                Self::new(
-                    "fake".to_string(),
-                    SshConnectionOptions {
-                        host: "<fake>".to_string(),
-                        port: Some(port),
-                        ..Default::default()
-                    },
-                    rx,
-                    Arc::new(fake::Delegate),
-                    cx,
-                )
-            })
+            .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx))
             .await
             .unwrap()
             .unwrap()
     }
 }
 
+enum ConnectionPoolEntry {
+    Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
+    Connected(Weak<dyn RemoteConnection>),
+}
+
+#[derive(Default)]
+struct ConnectionPool {
+    connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
+}
+
+impl Global for ConnectionPool {}
+
+impl ConnectionPool {
+    pub fn connect(
+        &mut self,
+        opts: SshConnectionOptions,
+        delegate: &Arc<dyn SshClientDelegate>,
+        cx: &mut AppContext,
+    ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
+        let connection = self.connections.get(&opts);
+        match connection {
+            Some(ConnectionPoolEntry::Connecting(task)) => {
+                let delegate = delegate.clone();
+                cx.spawn(|mut cx| async move {
+                    delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
+                })
+                .detach();
+                return task.clone();
+            }
+            Some(ConnectionPoolEntry::Connected(ssh)) => {
+                if let Some(ssh) = ssh.upgrade() {
+                    if !ssh.has_been_killed() {
+                        return Task::ready(Ok(ssh)).shared();
+                    }
+                }
+                self.connections.remove(&opts);
+            }
+            None => {}
+        }
+
+        let task = cx
+            .spawn({
+                let opts = opts.clone();
+                let delegate = delegate.clone();
+                |mut cx| async move {
+                    let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
+                        .await
+                        .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
+
+                    cx.update_global(|pool: &mut Self, _| {
+                        debug_assert!(matches!(
+                            pool.connections.get(&opts),
+                            Some(ConnectionPoolEntry::Connecting(_))
+                        ));
+                        match connection {
+                            Ok(connection) => {
+                                pool.connections.insert(
+                                    opts.clone(),
+                                    ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
+                                );
+                                Ok(connection)
+                            }
+                            Err(error) => {
+                                pool.connections.remove(&opts);
+                                Err(Arc::new(error))
+                            }
+                        }
+                    })?
+                }
+            })
+            .shared();
+
+        self.connections
+            .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
+        task
+    }
+}
+
 impl From<SshRemoteClient> for AnyProtoClient {
     fn from(client: SshRemoteClient) -> Self {
         AnyProtoClient::new(client.client.clone())
     }
 }
 
-#[async_trait]
-trait SshRemoteProcess: Send + Sync {
-    async fn kill(&mut self) -> Result<()>;
+#[async_trait(?Send)]
+trait RemoteConnection: Send + Sync {
+    #[allow(clippy::too_many_arguments)]
+    fn start_proxy(
+        &self,
+        remote_binary_path: PathBuf,
+        unique_identifier: String,
+        reconnect: bool,
+        incoming_tx: UnboundedSender<Envelope>,
+        outgoing_rx: UnboundedReceiver<Envelope>,
+        connection_activity_tx: Sender<()>,
+        delegate: Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<i32>>;
+    async fn get_remote_binary_path(
+        &self,
+        delegate: &Arc<dyn SshClientDelegate>,
+        reconnect: bool,
+        cx: &mut AsyncAppContext,
+    ) -> Result<PathBuf>;
+    async fn kill(&self) -> Result<()>;
+    fn has_been_killed(&self) -> bool;
     fn ssh_args(&self) -> Vec<String>;
     fn connection_options(&self) -> SshConnectionOptions;
+
+    #[cfg(any(test, feature = "test-support"))]
+    fn simulate_disconnect(&self, _: &AsyncAppContext) {}
 }
 
 struct SshRemoteConnection {
     socket: SshSocket,
-    master_process: process::Child,
+    master_process: Mutex<Option<process::Child>>,
+    platform: SshPlatform,
     _temp_dir: TempDir,
 }
 
-impl Drop for SshRemoteConnection {
-    fn drop(&mut self) {
-        if let Err(error) = self.master_process.kill() {
-            log::error!("failed to kill SSH master process: {}", error);
-        }
+#[async_trait(?Send)]
+impl RemoteConnection for SshRemoteConnection {
+    async fn kill(&self) -> Result<()> {
+        let Some(mut process) = self.master_process.lock().take() else {
+            return Ok(());
+        };
+        process.kill().ok();
+        process.status().await?;
+        Ok(())
     }
-}
 
-#[async_trait]
-impl SshRemoteProcess for SshRemoteConnection {
-    async fn kill(&mut self) -> Result<()> {
-        self.master_process.kill()?;
-
-        self.master_process.status().await?;
-
-        Ok(())
+    fn has_been_killed(&self) -> bool {
+        self.master_process.lock().is_none()
     }
 
     fn ssh_args(&self) -> Vec<String> {
@@ -1209,6 +1182,70 @@ impl SshRemoteProcess for SshRemoteConnection {
     fn connection_options(&self) -> SshConnectionOptions {
         self.socket.connection_options.clone()
     }
+
+    async fn get_remote_binary_path(
+        &self,
+        delegate: &Arc<dyn SshClientDelegate>,
+        reconnect: bool,
+        cx: &mut AsyncAppContext,
+    ) -> Result<PathBuf> {
+        let platform = self.platform;
+        let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
+        if !reconnect {
+            self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
+                .await?;
+        }
+
+        let socket = self.socket.clone();
+        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
+        Ok(remote_binary_path)
+    }
+
+    fn start_proxy(
+        &self,
+        remote_binary_path: PathBuf,
+        unique_identifier: String,
+        reconnect: bool,
+        incoming_tx: UnboundedSender<Envelope>,
+        outgoing_rx: UnboundedReceiver<Envelope>,
+        connection_activity_tx: Sender<()>,
+        delegate: Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<i32>> {
+        delegate.set_status(Some("Starting proxy"), cx);
+
+        let mut start_proxy_command = format!(
+            "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
+            std::env::var("RUST_LOG").unwrap_or_default(),
+            std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+            remote_binary_path,
+            unique_identifier,
+        );
+        if reconnect {
+            start_proxy_command.push_str(" --reconnect");
+        }
+
+        let ssh_proxy_process = match self
+            .socket
+            .ssh_command(start_proxy_command)
+            // IMPORTANT: we kill this process when we drop the task that uses it.
+            .kill_on_drop(true)
+            .spawn()
+        {
+            Ok(process) => process,
+            Err(error) => {
+                return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
+            }
+        };
+
+        Self::multiplex(
+            ssh_proxy_process,
+            incoming_tx,
+            outgoing_rx,
+            connection_activity_tx,
+            &cx,
+        )
+    }
 }
 
 impl SshRemoteConnection {
@@ -1305,6 +1342,7 @@ impl SshRemoteConnection {
             ])
             .arg(format!("ControlPath={}", socket_path.display()))
             .arg(&url)
+            .kill_on_drop(true)
             .spawn()?;
 
         // Wait for this ssh process to close its stdout, indicating that authentication
@@ -1348,16 +1386,139 @@ impl SshRemoteConnection {
             Err(anyhow!(error_message))?;
         }
 
+        let socket = SshSocket {
+            connection_options,
+            socket_path,
+        };
+
+        let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?;
+        let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?;
+
+        let os = match os.trim() {
+            "Darwin" => "macos",
+            "Linux" => "linux",
+            _ => Err(anyhow!("unknown uname os {os:?}"))?,
+        };
+        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
+            "aarch64"
+        } else if arch.starts_with("x86") || arch.starts_with("i686") {
+            "x86_64"
+        } else {
+            Err(anyhow!("unknown uname architecture {arch:?}"))?
+        };
+
+        let platform = SshPlatform { os, arch };
+
         Ok(Self {
-            socket: SshSocket {
-                connection_options,
-                socket_path,
-            },
-            master_process,
+            socket,
+            master_process: Mutex::new(Some(master_process)),
+            platform,
             _temp_dir: temp_dir,
         })
     }
 
+    fn multiplex(
+        mut ssh_proxy_process: Child,
+        incoming_tx: UnboundedSender<Envelope>,
+        mut outgoing_rx: UnboundedReceiver<Envelope>,
+        mut connection_activity_tx: Sender<()>,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<i32>> {
+        let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
+        let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
+        let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
+
+        let mut stdin_buffer = Vec::new();
+        let mut stdout_buffer = Vec::new();
+        let mut stderr_buffer = Vec::new();
+        let mut stderr_offset = 0;
+
+        let stdin_task = cx.background_executor().spawn(async move {
+            while let Some(outgoing) = outgoing_rx.next().await {
+                write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
+            }
+            anyhow::Ok(())
+        });
+
+        let stdout_task = cx.background_executor().spawn({
+            let mut connection_activity_tx = connection_activity_tx.clone();
+            async move {
+                loop {
+                    stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
+                    let len = child_stdout.read(&mut stdout_buffer).await?;
+
+                    if len == 0 {
+                        return anyhow::Ok(());
+                    }
+
+                    if len < MESSAGE_LEN_SIZE {
+                        child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
+                    }
+
+                    let message_len = message_len_from_buffer(&stdout_buffer);
+                    let envelope =
+                        read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
+                            .await?;
+                    connection_activity_tx.try_send(()).ok();
+                    incoming_tx.unbounded_send(envelope).ok();
+                }
+            }
+        });
+
+        let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
+            loop {
+                stderr_buffer.resize(stderr_offset + 1024, 0);
+
+                let len = child_stderr
+                    .read(&mut stderr_buffer[stderr_offset..])
+                    .await?;
+                if len == 0 {
+                    return anyhow::Ok(());
+                }
+
+                stderr_offset += len;
+                let mut start_ix = 0;
+                while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
+                    .iter()
+                    .position(|b| b == &b'\n')
+                {
+                    let line_ix = start_ix + ix;
+                    let content = &stderr_buffer[start_ix..line_ix];
+                    start_ix = line_ix + 1;
+                    if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
+                        record.log(log::logger())
+                    } else {
+                        eprintln!("(remote) {}", String::from_utf8_lossy(content));
+                    }
+                }
+                stderr_buffer.drain(0..start_ix);
+                stderr_offset -= start_ix;
+
+                connection_activity_tx.try_send(()).ok();
+            }
+        });
+
+        cx.spawn(|_| async move {
+            let result = futures::select! {
+                result = stdin_task.fuse() => {
+                    result.context("stdin")
+                }
+                result = stdout_task.fuse() => {
+                    result.context("stdout")
+                }
+                result = stderr_task.fuse() => {
+                    result.context("stderr")
+                }
+            };
+
+            let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
+            match result {
+                Ok(_) => Ok(status),
+                Err(error) => Err(error),
+            }
+        })
+    }
+
     async fn ensure_server_binary(
         &self,
         delegate: &Arc<dyn SshClientDelegate>,
@@ -1621,26 +1782,6 @@ impl SshRemoteConnection {
         Ok(())
     }
 
-    async fn query_platform(&self) -> Result<SshPlatform> {
-        let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
-        let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
-
-        let os = match os.trim() {
-            "Darwin" => "macos",
-            "Linux" => "linux",
-            _ => Err(anyhow!("unknown uname os {os:?}"))?,
-        };
-        let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
-            "aarch64"
-        } else if arch.starts_with("x86") || arch.starts_with("i686") {
-            "x86_64"
-        } else {
-            Err(anyhow!("unknown uname architecture {arch:?}"))?
-        };
-
-        Ok(SshPlatform { os, arch })
-    }
-
     async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
         let mut command = process::Command::new("scp");
         let output = self
@@ -1974,50 +2115,86 @@ mod fake {
         },
         select_biased, FutureExt, SinkExt, StreamExt,
     };
-    use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
+    use gpui::{AsyncAppContext, SemanticVersion, Task};
     use rpc::proto::Envelope;
 
     use super::{
-        ChannelClient, ServerBinary, SshClientDelegate, SshConnectionOptions, SshPlatform,
-        SshRemoteProcess,
+        ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions,
+        SshPlatform,
     };
 
-    pub(super) struct SshRemoteConnection {
-        connection_options: SshConnectionOptions,
+    pub(super) struct FakeRemoteConnection {
+        pub(super) connection_options: SshConnectionOptions,
+        pub(super) server_channel: Arc<ChannelClient>,
+        pub(super) server_cx: SendableCx,
     }
 
-    impl SshRemoteConnection {
-        pub(super) fn new(
-            connection_options: &SshConnectionOptions,
-        ) -> Option<Box<dyn SshRemoteProcess>> {
-            if connection_options.host == "<fake>" {
-                return Some(Box::new(Self {
-                    connection_options: connection_options.clone(),
-                }));
-            }
-            return None;
+    pub(super) struct SendableCx(AsyncAppContext);
+    // safety: you can only get the other cx on the main thread.
+    impl SendableCx {
+        pub(super) fn new(cx: AsyncAppContext) -> Self {
+            Self(cx)
+        }
+        fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
+            self.0.clone()
+        }
+    }
+    unsafe impl Send for SendableCx {}
+    unsafe impl Sync for SendableCx {}
+
+    #[async_trait(?Send)]
+    impl RemoteConnection for FakeRemoteConnection {
+        async fn kill(&self) -> Result<()> {
+            Ok(())
+        }
+
+        fn has_been_killed(&self) -> bool {
+            false
+        }
+
+        fn ssh_args(&self) -> Vec<String> {
+            Vec::new()
+        }
+
+        fn connection_options(&self) -> SshConnectionOptions {
+            self.connection_options.clone()
+        }
+
+        fn simulate_disconnect(&self, cx: &AsyncAppContext) {
+            let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+            let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+            self.server_channel
+                .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
+        }
+
+        async fn get_remote_binary_path(
+            &self,
+            _delegate: &Arc<dyn SshClientDelegate>,
+            _reconnect: bool,
+            _cx: &mut AsyncAppContext,
+        ) -> Result<PathBuf> {
+            Ok(PathBuf::new())
         }
-        pub(super) async fn multiplex(
-            connection_options: SshConnectionOptions,
+
+        fn start_proxy(
+            &self,
+            _remote_binary_path: PathBuf,
+            _unique_identifier: String,
+            _reconnect: bool,
             mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
             mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
             mut connection_activity_tx: Sender<()>,
+            _delegate: Arc<dyn SshClientDelegate>,
             cx: &mut AsyncAppContext,
         ) -> Task<Result<i32>> {
             let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
             let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
 
-            let (channel, server_cx) = cx
-                .update(|cx| {
-                    cx.update_global(|conns: &mut ServerConnections, _| {
-                        conns.get(connection_options.port.unwrap())
-                    })
-                })
-                .unwrap();
-            channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);
-
-            // send to proxy_tx to get to the server.
-            // receive from
+            self.server_channel.reconnect(
+                server_incoming_rx,
+                server_outgoing_tx,
+                &self.server_cx.get(cx),
+            );
 
             cx.background_executor().spawn(async move {
                 loop {
@@ -2041,39 +2218,6 @@ mod fake {
         }
     }
 
-    #[async_trait]
-    impl SshRemoteProcess for SshRemoteConnection {
-        async fn kill(&mut self) -> Result<()> {
-            Ok(())
-        }
-
-        fn ssh_args(&self) -> Vec<String> {
-            Vec::new()
-        }
-
-        fn connection_options(&self) -> SshConnectionOptions {
-            self.connection_options.clone()
-        }
-    }
-
-    #[derive(Default)]
-    pub(super) struct ServerConnections(Vec<(Arc<ChannelClient>, AsyncAppContext)>);
-    impl Global for ServerConnections {}
-
-    impl ServerConnections {
-        pub(super) fn push(&mut self, server: Arc<ChannelClient>, cx: AsyncAppContext) -> u16 {
-            self.0.push((server.clone(), cx));
-            self.0.len() as u16 - 1
-        }
-
-        pub(super) fn get(&mut self, port: u16) -> (Arc<ChannelClient>, AsyncAppContext) {
-            self.0
-                .get(port as usize)
-                .expect("no fake server for port")
-                .clone()
-        }
-    }
-
     pub(super) struct Delegate;
 
     impl SshClientDelegate for Delegate {

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -702,7 +702,7 @@ async fn init_test(
 ) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) {
     init_logger();
 
-    let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx);
+    let (opts, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx);
     let fs = FakeFs::new(server_cx.executor());
     fs.insert_tree(
         "/code",
@@ -744,7 +744,7 @@ async fn init_test(
         )
     });
 
-    let ssh = SshRemoteClient::fake_client(forwarder, cx).await;
+    let ssh = SshRemoteClient::fake_client(opts, cx).await;
     let project = build_project(ssh, cx);
     project
         .update(cx, {