Refactor ssh remoting - make ChannelClient type private (#36514)

Max Brunsfeld created

This PR is one step in a series of refactors to prepare for having
"remote" projects that do not use SSH. The main use cases for this are
WSL and dev containers.

Release Notes:

- N/A

Change summary

crates/editor/src/editor.rs                  |   5 
crates/project/src/project.rs                |  23 +--
crates/remote/src/ssh_session.rs             | 146 +++++++++------------
crates/remote_server/src/headless_project.rs |  67 ++++-----
crates/remote_server/src/unix.rs             |  13 -
crates/rpc/src/proto_client.rs               |  19 ++
crates/tasks_ui/src/tasks_ui.rs              |   6 
7 files changed, 133 insertions(+), 146 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -14895,10 +14895,7 @@ impl Editor {
             };
 
             let hide_runnables = project
-                .update(cx, |project, cx| {
-                    // Do not display any test indicators in non-dev server remote projects.
-                    project.is_via_collab() && project.ssh_connection_string(cx).is_none()
-                })
+                .update(cx, |project, _| project.is_via_collab())
                 .unwrap_or(true);
             if hide_runnables {
                 return;

crates/project/src/project.rs 🔗

@@ -1346,14 +1346,13 @@ impl Project {
             };
 
             // ssh -> local machine handlers
-            let ssh = ssh.read(cx);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.entity());
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.lsp_store);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.dap_store);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.settings_observer);
-            ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.git_store);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &cx.entity());
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.lsp_store);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.dap_store);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.settings_observer);
+            ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.git_store);
 
             ssh_proto.add_entity_message_handler(Self::handle_create_buffer_for_peer);
             ssh_proto.add_entity_message_handler(Self::handle_update_worktree);
@@ -1900,14 +1899,6 @@ impl Project {
         false
     }
 
-    pub fn ssh_connection_string(&self, cx: &App) -> Option<SharedString> {
-        if let Some(ssh_state) = &self.ssh_client {
-            return Some(ssh_state.read(cx).connection_string().into());
-        }
-
-        None
-    }
-
     pub fn ssh_connection_state(&self, cx: &App) -> Option<remote::ConnectionState> {
         self.ssh_client
             .as_ref()

crates/remote/src/ssh_session.rs 🔗

@@ -26,8 +26,7 @@ use parking_lot::Mutex;
 
 use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
 use rpc::{
-    AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
-    RpcError,
+    AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError,
     proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope},
 };
 use schemars::JsonSchema;
@@ -37,7 +36,6 @@ use smol::{
     process::{self, Child, Stdio},
 };
 use std::{
-    any::TypeId,
     collections::VecDeque,
     fmt, iter,
     ops::ControlFlow,
@@ -664,6 +662,7 @@ impl ConnectionIdentifier {
     pub fn setup() -> Self {
         Self::Setup(NEXT_ID.fetch_add(1, SeqCst))
     }
+
     // This string gets used in a socket name, and so must be relatively short.
     // The total length of:
     //   /home/{username}/.local/share/zed/server_state/{name}/stdout.sock
@@ -760,6 +759,15 @@ impl SshRemoteClient {
         })
     }
 
+    pub fn proto_client_from_channels(
+        incoming_rx: mpsc::UnboundedReceiver<Envelope>,
+        outgoing_tx: mpsc::UnboundedSender<Envelope>,
+        cx: &App,
+        name: &'static str,
+    ) -> AnyProtoClient {
+        ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into()
+    }
+
     pub fn shutdown_processes<T: RequestMessage>(
         &self,
         shutdown_request: Option<T>,
@@ -990,64 +998,63 @@ impl SshRemoteClient {
         };
 
         cx.spawn(async move |cx| {
-                let mut missed_heartbeats = 0;
-
-                let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
-                futures::pin_mut!(keepalive_timer);
+            let mut missed_heartbeats = 0;
 
-                loop {
-                    select_biased! {
-                        result = connection_activity_rx.next().fuse() => {
-                            if result.is_none() {
-                                log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
-                                return Ok(());
-                            }
+            let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
+            futures::pin_mut!(keepalive_timer);
 
-                            if missed_heartbeats != 0 {
-                                missed_heartbeats = 0;
-                                let _ =this.update(cx, |this, cx| {
-                                    this.handle_heartbeat_result(missed_heartbeats, cx)
-                                })?;
-                            }
+            loop {
+                select_biased! {
+                    result = connection_activity_rx.next().fuse() => {
+                        if result.is_none() {
+                            log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping.");
+                            return Ok(());
                         }
-                        _ = keepalive_timer => {
-                            log::debug!("Sending heartbeat to server...");
-
-                            let result = select_biased! {
-                                _ = connection_activity_rx.next().fuse() => {
-                                    Ok(())
-                                }
-                                ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
-                                    ping_result
-                                }
-                            };
-
-                            if result.is_err() {
-                                missed_heartbeats += 1;
-                                log::warn!(
-                                    "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
-                                    HEARTBEAT_TIMEOUT,
-                                    missed_heartbeats,
-                                    MAX_MISSED_HEARTBEATS
-                                );
-                            } else if missed_heartbeats != 0 {
-                                missed_heartbeats = 0;
-                            } else {
-                                continue;
-                            }
 
-                            let result = this.update(cx, |this, cx| {
+                        if missed_heartbeats != 0 {
+                            missed_heartbeats = 0;
+                            let _ =this.update(cx, |this, cx| {
                                 this.handle_heartbeat_result(missed_heartbeats, cx)
                             })?;
-                            if result.is_break() {
-                                return Ok(());
-                            }
                         }
                     }
+                    _ = keepalive_timer => {
+                        log::debug!("Sending heartbeat to server...");
+
+                        let result = select_biased! {
+                            _ = connection_activity_rx.next().fuse() => {
+                                Ok(())
+                            }
+                            ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
+                                ping_result
+                            }
+                        };
+
+                        if result.is_err() {
+                            missed_heartbeats += 1;
+                            log::warn!(
+                                "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
+                                HEARTBEAT_TIMEOUT,
+                                missed_heartbeats,
+                                MAX_MISSED_HEARTBEATS
+                            );
+                        } else if missed_heartbeats != 0 {
+                            missed_heartbeats = 0;
+                        } else {
+                            continue;
+                        }
 
-                    keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
+                        let result = this.update(cx, |this, cx| {
+                            this.handle_heartbeat_result(missed_heartbeats, cx)
+                        })?;
+                        if result.is_break() {
+                            return Ok(());
+                        }
+                    }
                 }
 
+                keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
+            }
         })
     }
 
@@ -1145,10 +1152,6 @@ impl SshRemoteClient {
         cx.notify();
     }
 
-    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
-        self.client.subscribe_to_entity(remote_id, entity);
-    }
-
     pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> {
         self.state
             .lock()
@@ -1222,7 +1225,7 @@ impl SshRemoteClient {
     pub fn fake_server(
         client_cx: &mut gpui::TestAppContext,
         server_cx: &mut gpui::TestAppContext,
-    ) -> (SshConnectionOptions, Arc<ChannelClient>) {
+    ) -> (SshConnectionOptions, AnyProtoClient) {
         let port = client_cx
             .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
         let opts = SshConnectionOptions {
@@ -1255,7 +1258,7 @@ impl SshRemoteClient {
             })
         });
 
-        (opts, server_client)
+        (opts, server_client.into())
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -2269,7 +2272,7 @@ impl SshRemoteConnection {
 
 type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 
-pub struct ChannelClient {
+struct ChannelClient {
     next_message_id: AtomicU32,
     outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
     buffer: Mutex<VecDeque<Envelope>>,
@@ -2281,7 +2284,7 @@ pub struct ChannelClient {
 }
 
 impl ChannelClient {
-    pub fn new(
+    fn new(
         incoming_rx: mpsc::UnboundedReceiver<Envelope>,
         outgoing_tx: mpsc::UnboundedSender<Envelope>,
         cx: &App,
@@ -2402,7 +2405,7 @@ impl ChannelClient {
         })
     }
 
-    pub fn reconnect(
+    fn reconnect(
         self: &Arc<Self>,
         incoming_rx: UnboundedReceiver<Envelope>,
         outgoing_tx: UnboundedSender<Envelope>,
@@ -2412,26 +2415,7 @@ impl ChannelClient {
         *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
     }
 
-    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
-        let id = (TypeId::of::<E>(), remote_id);
-
-        let mut message_handlers = self.message_handlers.lock();
-        if message_handlers
-            .entities_by_type_and_remote_id
-            .contains_key(&id)
-        {
-            panic!("already subscribed to entity");
-        }
-
-        message_handlers.entities_by_type_and_remote_id.insert(
-            id,
-            EntityMessageSubscriber::Entity {
-                handle: entity.downgrade().into(),
-            },
-        );
-    }
-
-    pub fn request<T: RequestMessage>(
+    fn request<T: RequestMessage>(
         &self,
         payload: T,
     ) -> impl 'static + Future<Output = Result<T::Response>> {
@@ -2453,7 +2437,7 @@ impl ChannelClient {
         }
     }
 
-    pub async fn resync(&self, timeout: Duration) -> Result<()> {
+    async fn resync(&self, timeout: Duration) -> Result<()> {
         smol::future::or(
             async {
                 self.request_internal(proto::FlushBufferedMessages {}, false)
@@ -2475,7 +2459,7 @@ impl ChannelClient {
         .await
     }
 
-    pub async fn ping(&self, timeout: Duration) -> Result<()> {
+    async fn ping(&self, timeout: Duration) -> Result<()> {
         smol::future::or(
             async {
                 self.request(proto::Ping {}).await?;

crates/remote_server/src/headless_project.rs 🔗

@@ -19,7 +19,6 @@ use project::{
     task_store::TaskStore,
     worktree_store::WorktreeStore,
 };
-use remote::ssh_session::ChannelClient;
 use rpc::{
     AnyProtoClient, TypedEnvelope,
     proto::{self, SSH_PEER_ID, SSH_PROJECT_ID},
@@ -50,7 +49,7 @@ pub struct HeadlessProject {
 }
 
 pub struct HeadlessAppState {
-    pub session: Arc<ChannelClient>,
+    pub session: AnyProtoClient,
     pub fs: Arc<dyn Fs>,
     pub http_client: Arc<dyn HttpClient>,
     pub node_runtime: NodeRuntime,
@@ -81,7 +80,7 @@ impl HeadlessProject {
 
         let worktree_store = cx.new(|cx| {
             let mut store = WorktreeStore::local(true, fs.clone());
-            store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            store.shared(SSH_PROJECT_ID, session.clone(), cx);
             store
         });
 
@@ -99,7 +98,7 @@ impl HeadlessProject {
 
         let buffer_store = cx.new(|cx| {
             let mut buffer_store = BufferStore::local(worktree_store.clone(), cx);
-            buffer_store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            buffer_store.shared(SSH_PROJECT_ID, session.clone(), cx);
             buffer_store
         });
 
@@ -117,7 +116,7 @@ impl HeadlessProject {
                 breakpoint_store.clone(),
                 cx,
             );
-            dap_store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            dap_store.shared(SSH_PROJECT_ID, session.clone(), cx);
             dap_store
         });
 
@@ -129,7 +128,7 @@ impl HeadlessProject {
                 fs.clone(),
                 cx,
             );
-            store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            store.shared(SSH_PROJECT_ID, session.clone(), cx);
             store
         });
 
@@ -152,7 +151,7 @@ impl HeadlessProject {
                 environment.clone(),
                 cx,
             );
-            task_store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            task_store.shared(SSH_PROJECT_ID, session.clone(), cx);
             task_store
         });
         let settings_observer = cx.new(|cx| {
@@ -162,7 +161,7 @@ impl HeadlessProject {
                 task_store.clone(),
                 cx,
             );
-            observer.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            observer.shared(SSH_PROJECT_ID, session.clone(), cx);
             observer
         });
 
@@ -183,7 +182,7 @@ impl HeadlessProject {
                 fs.clone(),
                 cx,
             );
-            lsp_store.shared(SSH_PROJECT_ID, session.clone().into(), cx);
+            lsp_store.shared(SSH_PROJECT_ID, session.clone(), cx);
             lsp_store
         });
 
@@ -210,8 +209,6 @@ impl HeadlessProject {
             cx,
         );
 
-        let client: AnyProtoClient = session.clone().into();
-
         // local_machine -> ssh handlers
         session.subscribe_to_entity(SSH_PROJECT_ID, &worktree_store);
         session.subscribe_to_entity(SSH_PROJECT_ID, &buffer_store);
@@ -223,44 +220,45 @@ impl HeadlessProject {
         session.subscribe_to_entity(SSH_PROJECT_ID, &settings_observer);
         session.subscribe_to_entity(SSH_PROJECT_ID, &git_store);
 
-        client.add_request_handler(cx.weak_entity(), Self::handle_list_remote_directory);
-        client.add_request_handler(cx.weak_entity(), Self::handle_get_path_metadata);
-        client.add_request_handler(cx.weak_entity(), Self::handle_shutdown_remote_server);
-        client.add_request_handler(cx.weak_entity(), Self::handle_ping);
+        session.add_request_handler(cx.weak_entity(), Self::handle_list_remote_directory);
+        session.add_request_handler(cx.weak_entity(), Self::handle_get_path_metadata);
+        session.add_request_handler(cx.weak_entity(), Self::handle_shutdown_remote_server);
+        session.add_request_handler(cx.weak_entity(), Self::handle_ping);
 
-        client.add_entity_request_handler(Self::handle_add_worktree);
-        client.add_request_handler(cx.weak_entity(), Self::handle_remove_worktree);
+        session.add_entity_request_handler(Self::handle_add_worktree);
+        session.add_request_handler(cx.weak_entity(), Self::handle_remove_worktree);
 
-        client.add_entity_request_handler(Self::handle_open_buffer_by_path);
-        client.add_entity_request_handler(Self::handle_open_new_buffer);
-        client.add_entity_request_handler(Self::handle_find_search_candidates);
-        client.add_entity_request_handler(Self::handle_open_server_settings);
+        session.add_entity_request_handler(Self::handle_open_buffer_by_path);
+        session.add_entity_request_handler(Self::handle_open_new_buffer);
+        session.add_entity_request_handler(Self::handle_find_search_candidates);
+        session.add_entity_request_handler(Self::handle_open_server_settings);
 
-        client.add_entity_request_handler(BufferStore::handle_update_buffer);
-        client.add_entity_message_handler(BufferStore::handle_close_buffer);
+        session.add_entity_request_handler(BufferStore::handle_update_buffer);
+        session.add_entity_message_handler(BufferStore::handle_close_buffer);
 
-        client.add_request_handler(
+        session.add_request_handler(
             extensions.clone().downgrade(),
             HeadlessExtensionStore::handle_sync_extensions,
         );
-        client.add_request_handler(
+        session.add_request_handler(
             extensions.clone().downgrade(),
             HeadlessExtensionStore::handle_install_extension,
         );
 
-        BufferStore::init(&client);
-        WorktreeStore::init(&client);
-        SettingsObserver::init(&client);
-        LspStore::init(&client);
-        TaskStore::init(Some(&client));
-        ToolchainStore::init(&client);
-        DapStore::init(&client, cx);
+        BufferStore::init(&session);
+        WorktreeStore::init(&session);
+        SettingsObserver::init(&session);
+        LspStore::init(&session);
+        TaskStore::init(Some(&session));
+        ToolchainStore::init(&session);
+        DapStore::init(&session, cx);
         // todo(debugger): Re init breakpoint store when we set it up for collab
         // BreakpointStore::init(&client);
-        GitStore::init(&client);
+        GitStore::init(&session);
 
         HeadlessProject {
-            session: client,
+            next_entry_id: Default::default(),
+            session,
             settings_observer,
             fs,
             worktree_store,
@@ -268,7 +266,6 @@ impl HeadlessProject {
             lsp_store,
             task_store,
             dap_store,
-            next_entry_id: Default::default(),
             languages,
             extensions,
             git_store,

crates/remote_server/src/unix.rs 🔗

@@ -19,11 +19,11 @@ use project::project_settings::ProjectSettings;
 
 use proto::CrashReport;
 use release_channel::{AppVersion, RELEASE_CHANNEL, ReleaseChannel};
-use remote::proxy::ProxyLaunchError;
-use remote::ssh_session::ChannelClient;
+use remote::SshRemoteClient;
 use remote::{
     json_log::LogRecord,
     protocol::{read_message, write_message},
+    proxy::ProxyLaunchError,
 };
 use reqwest_client::ReqwestClient;
 use rpc::proto::{self, Envelope, SSH_PROJECT_ID};
@@ -199,8 +199,7 @@ fn init_panic_hook(session_id: String) {
     }));
 }
 
-fn handle_crash_files_requests(project: &Entity<HeadlessProject>, client: &Arc<ChannelClient>) {
-    let client: AnyProtoClient = client.clone().into();
+fn handle_crash_files_requests(project: &Entity<HeadlessProject>, client: &AnyProtoClient) {
     client.add_request_handler(
         project.downgrade(),
         |_, _: TypedEnvelope<proto::GetCrashFiles>, _cx| async move {
@@ -276,7 +275,7 @@ fn start_server(
     listeners: ServerListeners,
     log_rx: Receiver<Vec<u8>>,
     cx: &mut App,
-) -> Arc<ChannelClient> {
+) -> AnyProtoClient {
     // This is the server idle timeout. If no connection comes in this timeout, the server will shut down.
     const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
 
@@ -395,7 +394,7 @@ fn start_server(
     })
     .detach();
 
-    ChannelClient::new(incoming_rx, outgoing_tx, cx, "server")
+    SshRemoteClient::proto_client_from_channels(incoming_rx, outgoing_tx, cx, "server")
 }
 
 fn init_paths() -> anyhow::Result<()> {
@@ -792,7 +791,7 @@ async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
 }
 
 fn initialize_settings(
-    session: Arc<ChannelClient>,
+    session: AnyProtoClient,
     fs: Arc<dyn Fs>,
     cx: &mut App,
 ) -> watch::Receiver<Option<NodeBinaryOptions>> {

crates/rpc/src/proto_client.rs 🔗

@@ -315,4 +315,23 @@ impl AnyProtoClient {
                 }),
             );
     }
+
+    pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Entity<E>) {
+        let id = (TypeId::of::<E>(), remote_id);
+
+        let mut message_handlers = self.0.message_handler_set().lock();
+        if message_handlers
+            .entities_by_type_and_remote_id
+            .contains_key(&id)
+        {
+            panic!("already subscribed to entity");
+        }
+
+        message_handlers.entities_by_type_and_remote_id.insert(
+            id,
+            EntityMessageSubscriber::Entity {
+                handle: entity.downgrade().into(),
+            },
+        );
+    }
 }

crates/tasks_ui/src/tasks_ui.rs 🔗

@@ -148,9 +148,9 @@ pub fn toggle_modal(
 ) -> Task<()> {
     let task_store = workspace.project().read(cx).task_store().clone();
     let workspace_handle = workspace.weak_handle();
-    let can_open_modal = workspace.project().update(cx, |project, cx| {
-        project.is_local() || project.ssh_connection_string(cx).is_some() || project.is_via_ssh()
-    });
+    let can_open_modal = workspace
+        .project()
+        .read_with(cx, |project, _| !project.is_via_collab());
     if can_open_modal {
         let task_contexts = task_contexts(workspace, window, cx);
         cx.spawn_in(window, async move |workspace, cx| {