ssh remoting: Add infrastructure to handle reconnects (#18572)

Thorsten Ball and Bennet created

This restructures the code in `remote` so that it's easier to replace
the current SSH connection with a new one in case of
disconnects/reconnects.

Right now, it successfully reconnects, BUT we're still missing the big
piece on the server-side: keeping the server process alive and
reconnecting to the same process that keeps the project-state.

Release Notes:

- N/A

---------

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/collab/src/tests/remote_editing_collaboration_tests.rs |   4 
crates/collab/src/tests/test_server.rs                        |   4 
crates/project/src/project.rs                                 |  83 
crates/project/src/terminals.rs                               |   8 
crates/recent_projects/src/ssh_connections.rs                 |   6 
crates/remote/src/remote.rs                                   |   2 
crates/remote/src/ssh_session.rs                              | 826 ++--
crates/remote_server/src/headless_project.rs                  |   4 
crates/remote_server/src/main.rs                              |   5 
crates/remote_server/src/remote_editing_tests.rs              |  11 
crates/workspace/src/workspace.rs                             |   4 
11 files changed, 559 insertions(+), 398 deletions(-)

Detailed changes

crates/collab/src/tests/remote_editing_collaboration_tests.rs 🔗

@@ -4,7 +4,7 @@ use fs::{FakeFs, Fs as _};
 use gpui::{Context as _, TestAppContext};
 use language::language_settings::all_language_settings;
 use project::ProjectPath;
-use remote::SshSession;
+use remote::SshRemoteClient;
 use remote_server::HeadlessProject;
 use serde_json::json;
 use std::{path::Path, sync::Arc};
@@ -24,7 +24,7 @@ async fn test_sharing_an_ssh_remote_project(
         .await;
 
     // Set up project on remote FS
-    let (client_ssh, server_ssh) = SshSession::fake(cx_a, server_cx);
+    let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx);
     let remote_fs = FakeFs::new(server_cx.executor());
     remote_fs
         .insert_tree(

crates/collab/src/tests/test_server.rs 🔗

@@ -25,7 +25,7 @@ use node_runtime::NodeRuntime;
 use notifications::NotificationStore;
 use parking_lot::Mutex;
 use project::{Project, WorktreeId};
-use remote::SshSession;
+use remote::SshRemoteClient;
 use rpc::{
     proto::{self, ChannelRole},
     RECEIVE_TIMEOUT,
@@ -835,7 +835,7 @@ impl TestClient {
     pub async fn build_ssh_project(
         &self,
         root_path: impl AsRef<Path>,
-        ssh: Arc<SshSession>,
+        ssh: Arc<SshRemoteClient>,
         cx: &mut TestAppContext,
     ) -> (Model<Project>, WorktreeId) {
         let project = cx.update(|cx| {

crates/project/src/project.rs 🔗

@@ -54,7 +54,7 @@ use parking_lot::{Mutex, RwLock};
 use paths::{local_tasks_file_relative_path, local_vscode_tasks_file_relative_path};
 pub use prettier_store::PrettierStore;
 use project_settings::{ProjectSettings, SettingsObserver, SettingsObserverEvent};
-use remote::SshSession;
+use remote::SshRemoteClient;
 use rpc::{proto::SSH_PROJECT_ID, AnyProtoClient, ErrorCode};
 use search::{SearchInputKind, SearchQuery, SearchResult};
 use search_history::SearchHistory;
@@ -138,7 +138,7 @@ pub struct Project {
     join_project_response_message_id: u32,
     user_store: Model<UserStore>,
     fs: Arc<dyn Fs>,
-    ssh_session: Option<Arc<SshSession>>,
+    ssh_client: Option<Arc<SshRemoteClient>>,
     client_state: ProjectClientState,
     collaborators: HashMap<proto::PeerId, Collaborator>,
     client_subscriptions: Vec<client::Subscription>,
@@ -643,7 +643,7 @@ impl Project {
                 user_store,
                 settings_observer,
                 fs,
-                ssh_session: None,
+                ssh_client: None,
                 buffers_needing_diff: Default::default(),
                 git_diff_debouncer: DebouncedDelay::new(),
                 terminals: Terminals {
@@ -664,7 +664,7 @@ impl Project {
     }
 
     pub fn ssh(
-        ssh: Arc<SshSession>,
+        ssh: Arc<SshRemoteClient>,
         client: Arc<Client>,
         node: NodeRuntime,
         user_store: Model<UserStore>,
@@ -682,14 +682,14 @@ impl Project {
                 SnippetProvider::new(fs.clone(), BTreeSet::from_iter([global_snippets_dir]), cx);
 
             let worktree_store =
-                cx.new_model(|_| WorktreeStore::remote(false, ssh.clone().into(), 0, None));
+                cx.new_model(|_| WorktreeStore::remote(false, ssh.to_proto_client(), 0, None));
             cx.subscribe(&worktree_store, Self::on_worktree_store_event)
                 .detach();
 
             let buffer_store = cx.new_model(|cx| {
                 BufferStore::remote(
                     worktree_store.clone(),
-                    ssh.clone().into(),
+                    ssh.to_proto_client(),
                     SSH_PROJECT_ID,
                     cx,
                 )
@@ -698,7 +698,7 @@ impl Project {
                 .detach();
 
             let settings_observer = cx.new_model(|cx| {
-                SettingsObserver::new_ssh(ssh.clone().into(), worktree_store.clone(), cx)
+                SettingsObserver::new_ssh(ssh.to_proto_client(), worktree_store.clone(), cx)
             });
             cx.subscribe(&settings_observer, Self::on_settings_observer_event)
                 .detach();
@@ -709,7 +709,7 @@ impl Project {
                     buffer_store.clone(),
                     worktree_store.clone(),
                     languages.clone(),
-                    ssh.clone().into(),
+                    ssh.to_proto_client(),
                     SSH_PROJECT_ID,
                     cx,
                 )
@@ -733,7 +733,7 @@ impl Project {
                 user_store,
                 settings_observer,
                 fs,
-                ssh_session: Some(ssh.clone()),
+                ssh_client: Some(ssh.clone()),
                 buffers_needing_diff: Default::default(),
                 git_diff_debouncer: DebouncedDelay::new(),
                 terminals: Terminals {
@@ -751,7 +751,7 @@ impl Project {
                 search_excluded_history: Self::new_search_history(),
             };
 
-            let client: AnyProtoClient = ssh.clone().into();
+            let client: AnyProtoClient = ssh.to_proto_client();
 
             ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle());
             ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store);
@@ -907,7 +907,7 @@ impl Project {
                 user_store: user_store.clone(),
                 snippets,
                 fs,
-                ssh_session: None,
+                ssh_client: None,
                 settings_observer: settings_observer.clone(),
                 client_subscriptions: Default::default(),
                 _subscriptions: vec![cx.on_release(Self::release)],
@@ -1230,7 +1230,7 @@ impl Project {
         match self.client_state {
             ProjectClientState::Remote { replica_id, .. } => replica_id,
             _ => {
-                if self.ssh_session.is_some() {
+                if self.ssh_client.is_some() {
                     1
                 } else {
                     0
@@ -1638,7 +1638,7 @@ impl Project {
     pub fn is_local(&self) -> bool {
         match &self.client_state {
             ProjectClientState::Local | ProjectClientState::Shared { .. } => {
-                self.ssh_session.is_none()
+                self.ssh_client.is_none()
             }
             ProjectClientState::Remote { .. } => false,
         }
@@ -1647,7 +1647,7 @@ impl Project {
     pub fn is_via_ssh(&self) -> bool {
         match &self.client_state {
             ProjectClientState::Local | ProjectClientState::Shared { .. } => {
-                self.ssh_session.is_some()
+                self.ssh_client.is_some()
             }
             ProjectClientState::Remote { .. } => false,
         }
@@ -1933,8 +1933,9 @@ impl Project {
             }
             BufferStoreEvent::BufferChangedFilePath { .. } => {}
             BufferStoreEvent::BufferDropped(buffer_id) => {
-                if let Some(ref ssh_session) = self.ssh_session {
-                    ssh_session
+                if let Some(ref ssh_client) = self.ssh_client {
+                    ssh_client
+                        .to_proto_client()
                         .send(proto::CloseBuffer {
                             project_id: 0,
                             buffer_id: buffer_id.to_proto(),
@@ -2139,13 +2140,14 @@ impl Project {
             } => {
                 let operation = language::proto::serialize_operation(operation);
 
-                if let Some(ssh) = &self.ssh_session {
-                    ssh.send(proto::UpdateBuffer {
-                        project_id: 0,
-                        buffer_id: buffer_id.to_proto(),
-                        operations: vec![operation.clone()],
-                    })
-                    .ok();
+                if let Some(ssh) = &self.ssh_client {
+                    ssh.to_proto_client()
+                        .send(proto::UpdateBuffer {
+                            project_id: 0,
+                            buffer_id: buffer_id.to_proto(),
+                            operations: vec![operation.clone()],
+                        })
+                        .ok();
                 }
 
                 self.enqueue_buffer_ordered_message(BufferOrderedMessage::Operation {
@@ -2825,14 +2827,13 @@ impl Project {
     ) -> Receiver<Model<Buffer>> {
         let (tx, rx) = smol::channel::unbounded();
 
-        let (client, remote_id): (AnyProtoClient, _) =
-            if let Some(ssh_session) = self.ssh_session.clone() {
-                (ssh_session.into(), 0)
-            } else if let Some(remote_id) = self.remote_id() {
-                (self.client.clone().into(), remote_id)
-            } else {
-                return rx;
-            };
+        let (client, remote_id): (AnyProtoClient, _) = if let Some(ssh_client) = &self.ssh_client {
+            (ssh_client.to_proto_client(), 0)
+        } else if let Some(remote_id) = self.remote_id() {
+            (self.client.clone().into(), remote_id)
+        } else {
+            return rx;
+        };
 
         let request = client.request(proto::FindSearchCandidates {
             project_id: remote_id,
@@ -2961,11 +2962,13 @@ impl Project {
 
                     exists.then(|| ResolvedPath::AbsPath(expanded))
                 })
-            } else if let Some(ssh_session) = self.ssh_session.as_ref() {
-                let request = ssh_session.request(proto::CheckFileExists {
-                    project_id: SSH_PROJECT_ID,
-                    path: path.to_string(),
-                });
+            } else if let Some(ssh_client) = self.ssh_client.as_ref() {
+                let request = ssh_client
+                    .to_proto_client()
+                    .request(proto::CheckFileExists {
+                        project_id: SSH_PROJECT_ID,
+                        path: path.to_string(),
+                    });
                 cx.background_executor().spawn(async move {
                     let response = request.await.log_err()?;
                     if response.exists {
@@ -3035,13 +3038,13 @@ impl Project {
     ) -> Task<Result<Vec<PathBuf>>> {
         if self.is_local() {
             DirectoryLister::Local(self.fs.clone()).list_directory(query, cx)
-        } else if let Some(session) = self.ssh_session.as_ref() {
+        } else if let Some(session) = self.ssh_client.as_ref() {
             let request = proto::ListRemoteDirectory {
                 dev_server_id: SSH_PROJECT_ID,
                 path: query,
             };
 
-            let response = session.request(request);
+            let response = session.to_proto_client().request(request);
             cx.background_executor().spawn(async move {
                 let response = response.await?;
                 Ok(response.entries.into_iter().map(PathBuf::from).collect())
@@ -3465,11 +3468,11 @@ impl Project {
         cx: AsyncAppContext,
     ) -> Result<proto::Ack> {
         let buffer_store = this.read_with(&cx, |this, cx| {
-            if let Some(ssh) = &this.ssh_session {
+            if let Some(ssh) = &this.ssh_client {
                 let mut payload = envelope.payload.clone();
                 payload.project_id = 0;
                 cx.background_executor()
-                    .spawn(ssh.request(payload))
+                    .spawn(ssh.to_proto_client().request(payload))
                     .detach_and_log_err(cx);
             }
             this.buffer_store.clone()

crates/project/src/terminals.rs 🔗

@@ -67,8 +67,12 @@ impl Project {
     }
 
     fn ssh_command(&self, cx: &AppContext) -> Option<SshCommand> {
-        if let Some(ssh_session) = self.ssh_session.as_ref() {
-            return Some(SshCommand::Direct(ssh_session.ssh_args()));
+        if let Some(args) = self
+            .ssh_client
+            .as_ref()
+            .and_then(|session| session.ssh_args())
+        {
+            return Some(SshCommand::Direct(args));
         }
 
         let dev_server_project_id = self.dev_server_project_id()?;

crates/recent_projects/src/ssh_connections.rs 🔗

@@ -11,7 +11,7 @@ use gpui::{
     Transformation, View,
 };
 use release_channel::{AppVersion, ReleaseChannel};
-use remote::{SshConnectionOptions, SshPlatform, SshSession};
+use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsSources};
@@ -376,12 +376,12 @@ pub fn connect_over_ssh(
     connection_options: SshConnectionOptions,
     ui: View<SshPrompt>,
     cx: &mut WindowContext,
-) -> Task<Result<Arc<SshSession>>> {
+) -> Task<Result<Arc<SshRemoteClient>>> {
     let window = cx.window_handle();
     let known_password = connection_options.password.clone();
 
     cx.spawn(|mut cx| async move {
-        remote::SshSession::client(
+        remote::SshRemoteClient::new(
             connection_options,
             Arc::new(SshClientDelegate {
                 window,

crates/remote/src/remote.rs 🔗

@@ -2,4 +2,4 @@ pub mod json_log;
 pub mod protocol;
 pub mod ssh_session;
 
-pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshSession};
+pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient};

crates/remote/src/ssh_session.rs 🔗

@@ -7,19 +7,23 @@ use crate::{
 use anyhow::{anyhow, Context as _, Result};
 use collections::HashMap;
 use futures::{
-    channel::{mpsc, oneshot},
+    channel::{
+        mpsc::{self, UnboundedReceiver, UnboundedSender},
+        oneshot,
+    },
     future::BoxFuture,
-    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
+    select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
+    StreamExt as _,
 };
 use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task};
 use parking_lot::Mutex;
 use rpc::{
     proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
-    EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
+    AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
 };
 use smol::{
     fs,
-    process::{self, Stdio},
+    process::{self, Child, Stdio},
 };
 use std::{
     any::TypeId,
@@ -44,22 +48,6 @@ pub struct SshSocket {
     socket_path: PathBuf,
 }
 
-pub struct SshSession {
-    next_message_id: AtomicU32,
-    response_channels: ResponseChannels, // Lock
-    outgoing_tx: mpsc::UnboundedSender<Envelope>,
-    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
-    client_socket: Option<SshSocket>,
-    state: Mutex<ProtoMessageHandlerSet>, // Lock
-    _io_task: Option<Task<Result<()>>>,
-}
-
-struct SshClientState {
-    socket: SshSocket,
-    master_process: process::Child,
-    _temp_dir: TempDir,
-}
-
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct SshConnectionOptions {
     pub host: String,
@@ -105,18 +93,13 @@ impl SshConnectionOptions {
     }
 }
 
-struct SpawnRequest {
-    command: String,
-    process_tx: oneshot::Sender<process::Child>,
-}
-
 #[derive(Copy, Clone, Debug)]
 pub struct SshPlatform {
     pub os: &'static str,
     pub arch: &'static str,
 }
 
-pub trait SshClientDelegate {
+pub trait SshClientDelegate: Send + Sync {
     fn ask_password(
         &self,
         prompt: String,
@@ -132,48 +115,249 @@ pub trait SshClientDelegate {
     fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
 }
 
-type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
+impl SshSocket {
+    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
+        let mut command = process::Command::new("ssh");
+        self.ssh_options(&mut command)
+            .arg(self.connection_options.ssh_url())
+            .arg(program);
+        command
+    }
+
+    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
+        command
+            .stdin(Stdio::piped())
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped())
+            .args(["-o", "ControlMaster=no", "-o"])
+            .arg(format!("ControlPath={}", self.socket_path.display()))
+    }
+
+    fn ssh_args(&self) -> Vec<String> {
+        vec![
+            "-o".to_string(),
+            "ControlMaster=no".to_string(),
+            "-o".to_string(),
+            format!("ControlPath={}", self.socket_path.display()),
+            self.connection_options.ssh_url(),
+        ]
+    }
+}
 
-impl SshSession {
-    pub async fn client(
+async fn run_cmd(command: &mut process::Command) -> Result<String> {
+    let output = command.output().await?;
+    if output.status.success() {
+        Ok(String::from_utf8_lossy(&output.stdout).to_string())
+    } else {
+        Err(anyhow!(
+            "failed to run command: {}",
+            String::from_utf8_lossy(&output.stderr)
+        ))
+    }
+}
+#[cfg(unix)]
+async fn read_with_timeout(
+    stdout: &mut process::ChildStdout,
+    timeout: std::time::Duration,
+    output: &mut Vec<u8>,
+) -> Result<(), std::io::Error> {
+    smol::future::or(
+        async {
+            stdout.read_to_end(output).await?;
+            Ok::<_, std::io::Error>(())
+        },
+        async {
+            smol::Timer::after(timeout).await;
+
+            Err(std::io::Error::new(
+                std::io::ErrorKind::TimedOut,
+                "Read operation timed out",
+            ))
+        },
+    )
+    .await
+}
+
+struct ChannelForwarder {
+    quit_tx: UnboundedSender<()>,
+    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
+}
+
+impl ChannelForwarder {
+    fn new(
+        mut incoming_tx: UnboundedSender<Envelope>,
+        mut outgoing_rx: UnboundedReceiver<Envelope>,
+        cx: &mut AsyncAppContext,
+    ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
+        let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
+
+        let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
+        let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
+
+        let forwarding_task = cx.background_executor().spawn(async move {
+            loop {
+                select_biased! {
+                    _ = quit_rx.next().fuse() => {
+                        break;
+                    },
+                    incoming_envelope = proxy_incoming_rx.next().fuse() => {
+                        if let Some(envelope) = incoming_envelope {
+                            if incoming_tx.send(envelope).await.is_err() {
+                                break;
+                            }
+                        } else {
+                            break;
+                        }
+                    }
+                    outgoing_envelope = outgoing_rx.next().fuse() => {
+                        if let Some(envelope) = outgoing_envelope {
+                            if proxy_outgoing_tx.send(envelope).await.is_err() {
+                                break;
+                            }
+                        } else {
+                            break;
+                        }
+                    }
+                }
+            }
+
+            (incoming_tx, outgoing_rx)
+        });
+
+        (
+            Self {
+                forwarding_task,
+                quit_tx,
+            },
+            proxy_incoming_tx,
+            proxy_outgoing_rx,
+        )
+    }
+
+    async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
+        let _ = self.quit_tx.send(()).await;
+        self.forwarding_task.await
+    }
+}
+
+struct SshRemoteClientState {
+    ssh_connection: SshRemoteConnection,
+    delegate: Arc<dyn SshClientDelegate>,
+    forwarder: ChannelForwarder,
+    _multiplex_task: Task<Result<()>>,
+}
+
+pub struct SshRemoteClient {
+    client: Arc<ChannelClient>,
+    inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
+}
+
+impl SshRemoteClient {
+    pub async fn new(
         connection_options: SshConnectionOptions,
         delegate: Arc<dyn SshClientDelegate>,
         cx: &mut AsyncAppContext,
     ) -> Result<Arc<Self>> {
-        let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
+        let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
+        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
 
-        let platform = client_state.query_platform().await?;
-        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
-        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
-        client_state
-            .ensure_server_binary(
-                &delegate,
-                &local_binary_path,
-                &remote_binary_path,
-                version,
+        let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
+        let this = Arc::new(Self {
+            client,
+            inner_state: Arc::new(Mutex::new(None)),
+        });
+
+        let inner_state = {
+            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+                ChannelForwarder::new(incoming_tx, outgoing_rx, cx);
+
+            let (ssh_connection, ssh_process) =
+                Self::establish_connection(connection_options.clone(), delegate.clone(), cx)
+                    .await?;
+
+            let multiplex_task = Self::multiplex(
+                this.clone(),
+                ssh_process,
+                proxy_incoming_tx,
+                proxy_outgoing_rx,
                 cx,
-            )
-            .await?;
+            );
 
-        let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
-        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
-        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
+            SshRemoteClientState {
+                ssh_connection,
+                delegate,
+                forwarder: proxy,
+                _multiplex_task: multiplex_task,
+            }
+        };
 
-        let socket = client_state.socket.clone();
-        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
+        this.inner_state.lock().replace(inner_state);
 
-        let mut remote_server_child = socket
-            .ssh_command(format!(
-                "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
-                std::env::var("RUST_LOG").unwrap_or_default(),
-                std::env::var("RUST_BACKTRACE").unwrap_or_default(),
-                remote_binary_path,
-            ))
-            .spawn()
-            .context("failed to spawn remote server")?;
-        let mut child_stderr = remote_server_child.stderr.take().unwrap();
-        let mut child_stdout = remote_server_child.stdout.take().unwrap();
-        let mut child_stdin = remote_server_child.stdin.take().unwrap();
+        Ok(this)
+    }
+
+    fn reconnect(this: Arc<Self>, cx: &mut AsyncAppContext) -> Result<()> {
+        let Some(state) = this.inner_state.lock().take() else {
+            return Err(anyhow!("reconnect is already in progress"));
+        };
+
+        let SshRemoteClientState {
+            mut ssh_connection,
+            delegate,
+            forwarder: proxy,
+            _multiplex_task,
+        } = state;
+        drop(_multiplex_task);
+
+        cx.spawn(|mut cx| async move {
+            let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
+
+            ssh_connection.master_process.kill()?;
+            ssh_connection
+                .master_process
+                .status()
+                .await
+                .context("Failed to kill ssh process")?;
+
+            let connection_options = ssh_connection.socket.connection_options.clone();
+
+            let (ssh_connection, ssh_process) =
+                Self::establish_connection(connection_options, delegate.clone(), &mut cx).await?;
+
+            let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+                ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+            let inner_state = SshRemoteClientState {
+                ssh_connection,
+                delegate,
+                forwarder: proxy,
+                _multiplex_task: Self::multiplex(
+                    this.clone(),
+                    ssh_process,
+                    proxy_incoming_tx,
+                    proxy_outgoing_rx,
+                    &mut cx,
+                ),
+            };
+            this.inner_state.lock().replace(inner_state);
+
+            anyhow::Ok(())
+        })
+        .detach();
+
+        anyhow::Ok(())
+    }
+
+    fn multiplex(
+        this: Arc<Self>,
+        mut ssh_process: Child,
+        incoming_tx: UnboundedSender<Envelope>,
+        mut outgoing_rx: UnboundedReceiver<Envelope>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<()>> {
+        let mut child_stderr = ssh_process.stderr.take().unwrap();
+        let mut child_stdout = ssh_process.stdout.take().unwrap();
+        let mut child_stdin = ssh_process.stdin.take().unwrap();
 
         let io_task = cx.background_executor().spawn(async move {
             let mut stdin_buffer = Vec::new();
@@ -194,27 +378,15 @@ impl SshSession {
                         write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
                     }
 
-                    request = spawn_process_rx.next().fuse() => {
-                        let Some(request) = request else {
-                            return Ok(());
-                        };
-
-                        log::info!("spawn process: {:?}", request.command);
-                        let child = client_state.socket
-                            .ssh_command(&request.command)
-                            .spawn()
-                            .context("failed to create channel")?;
-                        request.process_tx.send(child).ok();
-                    }
-
                     result = child_stdout.read(&mut stdout_buffer).fuse() => {
                         match result {
                             Ok(0) => {
                                 child_stdin.close().await?;
                                 outgoing_rx.close();
-                                let status = remote_server_child.status().await?;
+                                let status = ssh_process.status().await?;
                                 if !status.success() {
-                                    log::error!("channel exited with status: {status:?}");
+                                    log::error!("ssh process exited with status: {status:?}");
+                                    return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
                                 }
                                 return Ok(());
                             }
@@ -267,239 +439,112 @@ impl SshSession {
             }
         });
 
-        cx.update(|cx| {
-            Self::new(
-                incoming_rx,
-                outgoing_tx,
-                spawn_process_tx,
-                Some(socket),
-                Some(io_task),
-                cx,
-            )
-        })
-    }
+        cx.spawn(|mut cx| async move {
+            let result = io_task.await;
 
-    pub fn server(
-        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
-        outgoing_tx: mpsc::UnboundedSender<Envelope>,
-        cx: &AppContext,
-    ) -> Arc<SshSession> {
-        let (tx, _rx) = mpsc::unbounded();
-        Self::new(incoming_rx, outgoing_tx, tx, None, None, cx)
-    }
-
-    #[cfg(any(test, feature = "test-support"))]
-    pub fn fake(
-        client_cx: &mut gpui::TestAppContext,
-        server_cx: &mut gpui::TestAppContext,
-    ) -> (Arc<Self>, Arc<Self>) {
-        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
-        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
-        let (tx, _rx) = mpsc::unbounded();
-        (
-            client_cx.update(|cx| {
-                Self::new(
-                    server_to_client_rx,
-                    client_to_server_tx,
-                    tx.clone(),
-                    None, // todo()
-                    None,
-                    cx,
-                )
-            }),
-            server_cx.update(|cx| {
-                Self::new(
-                    client_to_server_rx,
-                    server_to_client_tx,
-                    tx.clone(),
-                    None,
-                    None,
-                    cx,
-                )
-            }),
-        )
-    }
-
-    fn new(
-        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
-        outgoing_tx: mpsc::UnboundedSender<Envelope>,
-        spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
-        client_socket: Option<SshSocket>,
-        io_task: Option<Task<Result<()>>>,
-        cx: &AppContext,
-    ) -> Arc<SshSession> {
-        let this = Arc::new(Self {
-            next_message_id: AtomicU32::new(0),
-            response_channels: ResponseChannels::default(),
-            outgoing_tx,
-            spawn_process_tx,
-            client_socket,
-            state: Default::default(),
-            _io_task: io_task,
-        });
-
-        cx.spawn(|cx| {
-            let this = Arc::downgrade(&this);
-            async move {
-                let peer_id = PeerId { owner_id: 0, id: 0 };
-                while let Some(incoming) = incoming_rx.next().await {
-                    let Some(this) = this.upgrade() else {
-                        return anyhow::Ok(());
-                    };
-
-                    if let Some(request_id) = incoming.responding_to {
-                        let request_id = MessageId(request_id);
-                        let sender = this.response_channels.lock().remove(&request_id);
-                        if let Some(sender) = sender {
-                            let (tx, rx) = oneshot::channel();
-                            if incoming.payload.is_some() {
-                                sender.send((incoming, tx)).ok();
-                            }
-                            rx.await.ok();
-                        }
-                    } else if let Some(envelope) =
-                        build_typed_envelope(peer_id, Instant::now(), incoming)
-                    {
-                        let type_name = envelope.payload_type_name();
-                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
-                            &this.state,
-                            envelope,
-                            this.clone().into(),
-                            cx.clone(),
-                        ) {
-                            log::debug!("ssh message received. name:{type_name}");
-                            match future.await {
-                                Ok(_) => {
-                                    log::debug!("ssh message handled. name:{type_name}");
-                                }
-                                Err(error) => {
-                                    log::error!(
-                                        "error handling message. type:{type_name}, error:{error}",
-                                    );
-                                }
-                            }
-                        } else {
-                            log::error!("unhandled ssh message name:{type_name}");
-                        }
-                    }
-                }
-                anyhow::Ok(())
+            if let Err(error) = result {
+                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
+                Self::reconnect(this, &mut cx).ok();
             }
-        })
-        .detach();
-
-        this
-    }
 
-    pub fn request<T: RequestMessage>(
-        &self,
-        payload: T,
-    ) -> impl 'static + Future<Output = Result<T::Response>> {
-        log::debug!("ssh request start. name:{}", T::NAME);
-        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
-        async move {
-            let response = response.await?;
-            log::debug!("ssh request finish. name:{}", T::NAME);
-            T::Response::from_envelope(response)
-                .ok_or_else(|| anyhow!("received a response of the wrong type"))
-        }
-    }
-
-    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
-        log::debug!("ssh send name:{}", T::NAME);
-        self.send_dynamic(payload.into_envelope(0, None, None))
+            Ok(())
+        })
     }
 
-    pub fn request_dynamic(
-        &self,
-        mut envelope: proto::Envelope,
-        type_name: &'static str,
-    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
-        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
-        let (tx, rx) = oneshot::channel();
-        let mut response_channels_lock = self.response_channels.lock();
-        response_channels_lock.insert(MessageId(envelope.id), tx);
-        drop(response_channels_lock);
-        let result = self.outgoing_tx.unbounded_send(envelope);
-        async move {
-            if let Err(error) = &result {
-                log::error!("failed to send message: {}", error);
-                return Err(anyhow!("failed to send message: {}", error));
-            }
-
-            let response = rx.await.context("connection lost")?.0;
-            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
-                return Err(RpcError::from_proto(error, type_name));
-            }
-            Ok(response)
-        }
-    }
+    async fn establish_connection(
+        connection_options: SshConnectionOptions,
+        delegate: Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Result<(SshRemoteConnection, Child)> {
+        let ssh_connection =
+            SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
 
-    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
-        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
-        self.outgoing_tx.unbounded_send(envelope)?;
-        Ok(())
-    }
+        let platform = ssh_connection.query_platform().await?;
+        let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
+        let remote_binary_path = delegate.remote_server_binary_path(cx)?;
+        ssh_connection
+            .ensure_server_binary(
+                &delegate,
+                &local_binary_path,
+                &remote_binary_path,
+                version,
+                cx,
+            )
+            .await?;
 
-    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
-        let id = (TypeId::of::<E>(), remote_id);
+        let socket = ssh_connection.socket.clone();
+        run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
 
-        let mut state = self.state.lock();
-        if state.entities_by_type_and_remote_id.contains_key(&id) {
-            panic!("already subscribed to entity");
-        }
+        let ssh_process = socket
+            .ssh_command(format!(
+                "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
+                std::env::var("RUST_LOG").unwrap_or_default(),
+                std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+                remote_binary_path,
+            ))
+            .spawn()
+            .context("failed to spawn remote server")?;
 
-        state.entities_by_type_and_remote_id.insert(
-            id,
-            EntityMessageSubscriber::Entity {
-                handle: entity.downgrade().into(),
-            },
-        );
+        Ok((ssh_connection, ssh_process))
     }
 
-    pub async fn spawn_process(&self, command: String) -> process::Child {
-        let (process_tx, process_rx) = oneshot::channel();
-        self.spawn_process_tx
-            .unbounded_send(SpawnRequest {
-                command,
-                process_tx,
-            })
-            .ok();
-        process_rx.await.unwrap()
+    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
+        self.client.subscribe_to_entity(remote_id, entity);
     }
 
-    pub fn ssh_args(&self) -> Vec<String> {
-        self.client_socket.as_ref().unwrap().ssh_args()
+    pub fn ssh_args(&self) -> Option<Vec<String>> {
+        let state = self.inner_state.lock();
+        state
+            .as_ref()
+            .map(|state| state.ssh_connection.socket.ssh_args())
     }
-}
 
-impl ProtoClient for SshSession {
-    fn request(
-        &self,
-        envelope: proto::Envelope,
-        request_type: &'static str,
-    ) -> BoxFuture<'static, Result<proto::Envelope>> {
-        self.request_dynamic(envelope, request_type).boxed()
+    pub fn to_proto_client(&self) -> AnyProtoClient {
+        self.client.clone().into()
     }
 
-    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
-        self.send_dynamic(envelope)
-    }
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn fake(
+        client_cx: &mut gpui::TestAppContext,
+        server_cx: &mut gpui::TestAppContext,
+    ) -> (Arc<Self>, Arc<ChannelClient>) {
+        let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
+        let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
 
-    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
-        self.send_dynamic(envelope)
+        (
+            client_cx.update(|cx| {
+                let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
+                Arc::new(Self {
+                    client,
+                    inner_state: Arc::new(Mutex::new(None)),
+                })
+            }),
+            server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
+        )
     }
+}
 
-    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
-        &self.state
+impl From<SshRemoteClient> for AnyProtoClient {
+    fn from(client: SshRemoteClient) -> Self {
+        AnyProtoClient::new(client.client.clone())
     }
+}
 
-    fn is_via_collab(&self) -> bool {
-        false
+struct SshRemoteConnection {
+    socket: SshSocket,
+    master_process: process::Child,
+    _temp_dir: TempDir,
+}
+
+impl Drop for SshRemoteConnection {
+    fn drop(&mut self) {
+        if let Err(error) = self.master_process.kill() {
+            log::error!("failed to kill SSH master process: {}", error);
+        }
     }
 }
 
-impl SshClientState {
+impl SshRemoteConnection {
     #[cfg(not(unix))]
     async fn new(
         _connection_options: SshConnectionOptions,
@@ -740,74 +785,181 @@ impl SshClientState {
     }
 }
 
-#[cfg(unix)]
-async fn read_with_timeout(
-    stdout: &mut process::ChildStdout,
-    timeout: std::time::Duration,
-    output: &mut Vec<u8>,
-) -> Result<(), std::io::Error> {
-    smol::future::or(
-        async {
-            stdout.read_to_end(output).await?;
-            Ok::<_, std::io::Error>(())
-        },
-        async {
-            smol::Timer::after(timeout).await;
+type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 
-            Err(std::io::Error::new(
-                std::io::ErrorKind::TimedOut,
-                "Read operation timed out",
-            ))
-        },
-    )
-    .await
+pub struct ChannelClient {
+    next_message_id: AtomicU32,
+    outgoing_tx: mpsc::UnboundedSender<Envelope>,
+    response_channels: ResponseChannels,             // Lock
+    message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
 }
 
-impl Drop for SshClientState {
-    fn drop(&mut self) {
-        if let Err(error) = self.master_process.kill() {
-            log::error!("failed to kill SSH master process: {}", error);
+impl ChannelClient {
+    pub fn new(
+        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
+        outgoing_tx: mpsc::UnboundedSender<Envelope>,
+        cx: &AppContext,
+    ) -> Arc<Self> {
+        let this = Arc::new(Self {
+            outgoing_tx,
+            next_message_id: AtomicU32::new(0),
+            response_channels: ResponseChannels::default(),
+            message_handlers: Default::default(),
+        });
+
+        Self::start_handling_messages(this.clone(), incoming_rx, cx);
+
+        this
+    }
+
+    fn start_handling_messages(
+        this: Arc<Self>,
+        mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
+        cx: &AppContext,
+    ) {
+        cx.spawn(|cx| {
+            let this = Arc::downgrade(&this);
+            async move {
+                let peer_id = PeerId { owner_id: 0, id: 0 };
+                while let Some(incoming) = incoming_rx.next().await {
+                    let Some(this) = this.upgrade() else {
+                        return anyhow::Ok(());
+                    };
+
+                    if let Some(request_id) = incoming.responding_to {
+                        let request_id = MessageId(request_id);
+                        let sender = this.response_channels.lock().remove(&request_id);
+                        if let Some(sender) = sender {
+                            let (tx, rx) = oneshot::channel();
+                            if incoming.payload.is_some() {
+                                sender.send((incoming, tx)).ok();
+                            }
+                            rx.await.ok();
+                        }
+                    } else if let Some(envelope) =
+                        build_typed_envelope(peer_id, Instant::now(), incoming)
+                    {
+                        let type_name = envelope.payload_type_name();
+                        if let Some(future) = ProtoMessageHandlerSet::handle_message(
+                            &this.message_handlers,
+                            envelope,
+                            this.clone().into(),
+                            cx.clone(),
+                        ) {
+                            log::debug!("ssh message received. name:{type_name}");
+                            match future.await {
+                                Ok(_) => {
+                                    log::debug!("ssh message handled. name:{type_name}");
+                                }
+                                Err(error) => {
+                                    log::error!(
+                                        "error handling message. type:{type_name}, error:{error}",
+                                    );
+                                }
+                            }
+                        } else {
+                            log::error!("unhandled ssh message name:{type_name}");
+                        }
+                    }
+                }
+                anyhow::Ok(())
+            }
+        })
+        .detach();
+    }
+
+    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
+        let id = (TypeId::of::<E>(), remote_id);
+
+        let mut message_handlers = self.message_handlers.lock();
+        if message_handlers
+            .entities_by_type_and_remote_id
+            .contains_key(&id)
+        {
+            panic!("already subscribed to entity");
         }
+
+        message_handlers.entities_by_type_and_remote_id.insert(
+            id,
+            EntityMessageSubscriber::Entity {
+                handle: entity.downgrade().into(),
+            },
+        );
     }
-}
 
-impl SshSocket {
-    fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
-        let mut command = process::Command::new("ssh");
-        self.ssh_options(&mut command)
-            .arg(self.connection_options.ssh_url())
-            .arg(program);
-        command
+    pub fn request<T: RequestMessage>(
+        &self,
+        payload: T,
+    ) -> impl 'static + Future<Output = Result<T::Response>> {
+        log::debug!("ssh request start. name:{}", T::NAME);
+        let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
+        async move {
+            let response = response.await?;
+            log::debug!("ssh request finish. name:{}", T::NAME);
+            T::Response::from_envelope(response)
+                .ok_or_else(|| anyhow!("received a response of the wrong type"))
+        }
     }
 
-    fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
-        command
-            .stdin(Stdio::piped())
-            .stdout(Stdio::piped())
-            .stderr(Stdio::piped())
-            .args(["-o", "ControlMaster=no", "-o"])
-            .arg(format!("ControlPath={}", self.socket_path.display()))
+    pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
+        log::debug!("ssh send name:{}", T::NAME);
+        self.send_dynamic(payload.into_envelope(0, None, None))
     }
 
-    fn ssh_args(&self) -> Vec<String> {
-        vec![
-            "-o".to_string(),
-            "ControlMaster=no".to_string(),
-            "-o".to_string(),
-            format!("ControlPath={}", self.socket_path.display()),
-            self.connection_options.ssh_url(),
-        ]
+    pub fn request_dynamic(
+        &self,
+        mut envelope: proto::Envelope,
+        type_name: &'static str,
+    ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
+        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
+        let (tx, rx) = oneshot::channel();
+        let mut response_channels_lock = self.response_channels.lock();
+        response_channels_lock.insert(MessageId(envelope.id), tx);
+        drop(response_channels_lock);
+        let result = self.outgoing_tx.unbounded_send(envelope);
+        async move {
+            if let Err(error) = &result {
+                log::error!("failed to send message: {}", error);
+                return Err(anyhow!("failed to send message: {}", error));
+            }
+
+            let response = rx.await.context("connection lost")?.0;
+            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+                return Err(RpcError::from_proto(error, type_name));
+            }
+            Ok(response)
+        }
+    }
+
+    pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
+        envelope.id = self.next_message_id.fetch_add(1, SeqCst);
+        self.outgoing_tx.unbounded_send(envelope)?;
+        Ok(())
     }
 }
 
-async fn run_cmd(command: &mut process::Command) -> Result<String> {
-    let output = command.output().await?;
-    if output.status.success() {
-        Ok(String::from_utf8_lossy(&output.stdout).to_string())
-    } else {
-        Err(anyhow!(
-            "failed to run command: {}",
-            String::from_utf8_lossy(&output.stderr)
-        ))
+impl ProtoClient for ChannelClient {
+    fn request(
+        &self,
+        envelope: proto::Envelope,
+        request_type: &'static str,
+    ) -> BoxFuture<'static, Result<proto::Envelope>> {
+        self.request_dynamic(envelope, request_type).boxed()
+    }
+
+    fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
+        self.send_dynamic(envelope)
+    }
+
+    fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
+        self.send_dynamic(envelope)
+    }
+
+    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
+        &self.message_handlers
+    }
+
+    fn is_via_collab(&self) -> bool {
+        false
     }
 }

crates/remote_server/src/headless_project.rs 🔗

@@ -10,7 +10,7 @@ use project::{
     worktree_store::WorktreeStore,
     LspStore, LspStoreEvent, PrettierStore, ProjectPath, WorktreeId,
 };
-use remote::SshSession;
+use remote::ssh_session::ChannelClient;
 use rpc::{
     proto::{self, SSH_PEER_ID, SSH_PROJECT_ID},
     AnyProtoClient, TypedEnvelope,
@@ -41,7 +41,7 @@ impl HeadlessProject {
         project::Project::init_settings(cx);
     }
 
-    pub fn new(session: Arc<SshSession>, fs: Arc<dyn Fs>, cx: &mut ModelContext<Self>) -> Self {
+    pub fn new(session: Arc<ChannelClient>, fs: Arc<dyn Fs>, cx: &mut ModelContext<Self>) -> Self {
         let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
 
         let node_runtime = NodeRuntime::unavailable();

crates/remote_server/src/main.rs 🔗

@@ -6,7 +6,6 @@ use gpui::Context as _;
 use remote::{
     json_log::LogRecord,
     protocol::{read_message, write_message},
-    SshSession,
 };
 use remote_server::HeadlessProject;
 use smol::{io::AsyncWriteExt, stream::StreamExt as _, Async};
@@ -24,6 +23,8 @@ fn main() {
 
 #[cfg(not(windows))]
 fn main() {
+    use remote::ssh_session::ChannelClient;
+
     env_logger::builder()
         .format(|buf, record| {
             serde_json::to_writer(&mut *buf, &LogRecord::new(record))?;
@@ -55,7 +56,7 @@ fn main() {
         let mut stdin = Async::new(io::stdin()).unwrap();
         let mut stdout = Async::new(io::stdout()).unwrap();
 
-        let session = SshSession::server(incoming_rx, outgoing_tx, cx);
+        let session = ChannelClient::new(incoming_rx, outgoing_tx, cx);
         let project = cx.new_model(|cx| {
             HeadlessProject::new(
                 session.clone(),

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -15,7 +15,7 @@ use project::{
     search::{SearchQuery, SearchResult},
     Project, ProjectPath,
 };
-use remote::SshSession;
+use remote::SshRemoteClient;
 use serde_json::json;
 use settings::{Settings, SettingsLocation, SettingsStore};
 use smol::stream::StreamExt;
@@ -616,7 +616,7 @@ async fn init_test(
     cx: &mut TestAppContext,
     server_cx: &mut TestAppContext,
 ) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) {
-    let (client_ssh, server_ssh) = SshSession::fake(cx, server_cx);
+    let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx);
     init_logger();
 
     let fs = FakeFs::new(server_cx.executor());
@@ -642,8 +642,9 @@ async fn init_test(
     );
 
     server_cx.update(HeadlessProject::init);
-    let headless = server_cx.new_model(|cx| HeadlessProject::new(server_ssh, fs.clone(), cx));
-    let project = build_project(client_ssh, cx);
+    let headless =
+        server_cx.new_model(|cx| HeadlessProject::new(ssh_server_client, fs.clone(), cx));
+    let project = build_project(ssh_remote_client, cx);
 
     project
         .update(cx, {
@@ -654,7 +655,7 @@ async fn init_test(
     (project, headless, fs)
 }
 
-fn build_project(ssh: Arc<SshSession>, cx: &mut TestAppContext) -> Model<Project> {
+fn build_project(ssh: Arc<SshRemoteClient>, cx: &mut TestAppContext) -> Model<Project> {
     cx.update(|cx| {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);

crates/workspace/src/workspace.rs 🔗

@@ -61,7 +61,7 @@ use postage::stream::Stream;
 use project::{
     DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId,
 };
-use remote::{SshConnectionOptions, SshSession};
+use remote::{SshConnectionOptions, SshRemoteClient};
 use serde::Deserialize;
 use session::AppSession;
 use settings::{InvalidSettingsError, Settings};
@@ -5514,7 +5514,7 @@ pub fn join_hosted_project(
 pub fn open_ssh_project(
     window: WindowHandle<Workspace>,
     connection_options: SshConnectionOptions,
-    session: Arc<SshSession>,
+    session: Arc<SshRemoteClient>,
     app_state: Arc<AppState>,
     paths: Vec<PathBuf>,
     cx: &mut AppContext,