ssh remoting: Fix SSH connection not being closed (#18329)

Thorsten Ball and Bennet created

This fixes the `SshSession` being leaked.

There were two leaks:

1. `Arc<SshSession>` itself got leaked into the `SettingsObserver` that
   lives as long as the application. Fixed with a weak reference.
2. The two tasks spawned by an `SshSession` had a circular dependency
   and didn't exit while the other one was running. Fixed by fixing (1)
   and then attaching one of the tasks to the `SshSession`, which means
   it gets dropped with the session itself, which leads the other task
   to error and exit.

Co-authored-by: Bennet <bennet@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/project/src/project_settings.rs        | 13 +++++---
crates/recent_projects/src/recent_projects.rs |  2 
crates/remote/src/ssh_session.rs              | 31 ++++++++++++++++----
crates/rpc/src/proto_client.rs                | 20 ++++++++++++
crates/worktree/src/worktree.rs               |  2 
5 files changed, 53 insertions(+), 15 deletions(-)

Detailed changes

crates/project/src/project_settings.rs 🔗

@@ -334,17 +334,20 @@ impl SettingsObserver {
             .log_err();
         }
 
+        let weak_client = ssh.downgrade();
         cx.observe_global::<SettingsStore>(move |_, cx| {
             let new_settings = cx.global::<SettingsStore>().raw_user_settings();
             if &settings != new_settings {
                 settings = new_settings.clone()
             }
             if let Some(content) = serde_json::to_string(&settings).log_err() {
-                ssh.send(proto::UpdateUserSettings {
-                    project_id: 0,
-                    content,
-                })
-                .log_err();
+                if let Some(ssh) = weak_client.upgrade() {
+                    ssh.send(proto::UpdateUserSettings {
+                        project_id: 0,
+                        content,
+                    })
+                    .log_err();
+                }
             }
         })
         .detach();

crates/recent_projects/src/recent_projects.rs 🔗

@@ -509,7 +509,7 @@ impl PickerDelegate for RecentProjectsDelegate {
                                         .color(Color::Muted)
                                         .into_any_element()
                                 }
-                                SerializedWorkspaceLocation::Ssh(_) => Icon::new(IconName::Screen)
+                                SerializedWorkspaceLocation::Ssh(_) => Icon::new(IconName::Server)
                                     .color(Color::Muted)
                                     .into_any_element(),
                                 SerializedWorkspaceLocation::DevServer(_) => {

crates/remote/src/ssh_session.rs 🔗

@@ -11,7 +11,7 @@ use futures::{
     future::BoxFuture,
     select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
 };
-use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion};
+use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task};
 use parking_lot::Mutex;
 use rpc::{
     proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
@@ -51,6 +51,7 @@ pub struct SshSession {
     spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
     client_socket: Option<SshSocket>,
     state: Mutex<ProtoMessageHandlerSet>, // Lock
+    _io_task: Option<Task<Result<()>>>,
 }
 
 struct SshClientState {
@@ -173,8 +174,7 @@ impl SshSession {
         let mut child_stdout = remote_server_child.stdout.take().unwrap();
         let mut child_stdin = remote_server_child.stdin.take().unwrap();
 
-        let executor = cx.background_executor().clone();
-        executor.clone().spawn(async move {
+        let io_task = cx.background_executor().spawn(async move {
             let mut stdin_buffer = Vec::new();
             let mut stdout_buffer = Vec::new();
             let mut stderr_buffer = Vec::new();
@@ -264,9 +264,18 @@ impl SshSession {
                     }
                 }
             }
-        }).detach();
+        });
 
-        cx.update(|cx| Self::new(incoming_rx, outgoing_tx, spawn_process_tx, Some(socket), cx))
+        cx.update(|cx| {
+            Self::new(
+                incoming_rx,
+                outgoing_tx,
+                spawn_process_tx,
+                Some(socket),
+                Some(io_task),
+                cx,
+            )
+        })
     }
 
     pub fn server(
@@ -275,7 +284,7 @@ impl SshSession {
         cx: &AppContext,
     ) -> Arc<SshSession> {
         let (tx, _rx) = mpsc::unbounded();
-        Self::new(incoming_rx, outgoing_tx, tx, None, cx)
+        Self::new(incoming_rx, outgoing_tx, tx, None, None, cx)
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -293,6 +302,7 @@ impl SshSession {
                     client_to_server_tx,
                     tx.clone(),
                     None, // todo()
+                    None,
                     cx,
                 )
             }),
@@ -302,6 +312,7 @@ impl SshSession {
                     server_to_client_tx,
                     tx.clone(),
                     None,
+                    None,
                     cx,
                 )
             }),
@@ -313,6 +324,7 @@ impl SshSession {
         outgoing_tx: mpsc::UnboundedSender<Envelope>,
         spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
         client_socket: Option<SshSocket>,
+        io_task: Option<Task<Result<()>>>,
         cx: &AppContext,
     ) -> Arc<SshSession> {
         let this = Arc::new(Self {
@@ -322,13 +334,18 @@ impl SshSession {
             spawn_process_tx,
             client_socket,
             state: Default::default(),
+            _io_task: io_task,
         });
 
         cx.spawn(|cx| {
-            let this = this.clone();
+            let this = Arc::downgrade(&this);
             async move {
                 let peer_id = PeerId { owner_id: 0, id: 0 };
                 while let Some(incoming) = incoming_rx.next().await {
+                    let Some(this) = this.upgrade() else {
+                        return anyhow::Ok(());
+                    };
+
                     if let Some(request_id) = incoming.responding_to {
                         let request_id = MessageId(request_id);
                         let sender = this.response_channels.lock().remove(&request_id);

crates/rpc/src/proto_client.rs 🔗

@@ -10,11 +10,29 @@ use proto::{
     error::ErrorExt as _, AnyTypedEnvelope, EntityMessage, Envelope, EnvelopedMessage,
     RequestMessage, TypedEnvelope,
 };
-use std::{any::TypeId, sync::Arc};
+use std::{
+    any::TypeId,
+    sync::{Arc, Weak},
+};
 
 #[derive(Clone)]
 pub struct AnyProtoClient(Arc<dyn ProtoClient>);
 
+impl AnyProtoClient {
+    pub fn downgrade(&self) -> AnyWeakProtoClient {
+        AnyWeakProtoClient(Arc::downgrade(&self.0))
+    }
+}
+
+#[derive(Clone)]
+pub struct AnyWeakProtoClient(Weak<dyn ProtoClient>);
+
+impl AnyWeakProtoClient {
+    pub fn upgrade(&self) -> Option<AnyProtoClient> {
+        self.0.upgrade().map(AnyProtoClient)
+    }
+}
+
 pub trait ProtoClient: Send + Sync {
     fn request(
         &self,

crates/worktree/src/worktree.rs 🔗

@@ -472,7 +472,7 @@ impl Worktree {
                 disconnected: false,
             };
 
-            // Apply updates to a separate snapshto in a background task, then
+            // Apply updates to a separate snapshot in a background task, then
             // send them to a foreground task which updates the model.
             cx.background_executor()
                 .spawn(async move {