Merge pull request #680 from zed-industries/unregister-on-disconnect

Nathan Sobo created

Properly clear out registration and sharing state when a host loses their connection

Change summary

crates/project/src/project.rs     | 161 +++++++++++++++++---------------
crates/project/src/worktree.rs    |   9 +
crates/rpc/src/peer.rs            |   2 
crates/server/src/rpc.rs          | 100 ++++++++++++++++++++
crates/workspace/src/workspace.rs |   2 
5 files changed, 194 insertions(+), 80 deletions(-)

Detailed changes

crates/project/src/project.rs 🔗

@@ -302,31 +302,11 @@ impl Project {
                         let mut status = rpc.status();
                         while let Some(status) = status.next().await {
                             if let Some(this) = this.upgrade(&cx) {
-                                let remote_id = if status.is_connected() {
-                                    let response = rpc.request(proto::RegisterProject {}).await?;
-                                    Some(response.project_id)
+                                if status.is_connected() {
+                                    this.update(&mut cx, |this, cx| this.register(cx)).await?;
                                 } else {
-                                    None
-                                };
-
-                                if let Some(project_id) = remote_id {
-                                    let mut registrations = Vec::new();
-                                    this.update(&mut cx, |this, cx| {
-                                        for worktree in this.worktrees(cx).collect::<Vec<_>>() {
-                                            registrations.push(worktree.update(
-                                                cx,
-                                                |worktree, cx| {
-                                                    let worktree = worktree.as_local_mut().unwrap();
-                                                    worktree.register(project_id, cx)
-                                                },
-                                            ));
-                                        }
-                                    });
-                                    for registration in registrations {
-                                        registration.await?;
-                                    }
+                                    this.update(&mut cx, |this, cx| this.unregister(cx));
                                 }
-                                this.update(&mut cx, |this, cx| this.set_remote_id(remote_id, cx));
                             }
                         }
                         Ok(())
@@ -558,17 +538,54 @@ impl Project {
         &self.fs
     }
 
-    fn set_remote_id(&mut self, remote_id: Option<u64>, cx: &mut ModelContext<Self>) {
+    fn unregister(&mut self, cx: &mut ModelContext<Self>) {
+        self.unshare(cx);
+        for worktree in &self.worktrees {
+            if let Some(worktree) = worktree.upgrade(cx) {
+                worktree.update(cx, |worktree, _| {
+                    worktree.as_local_mut().unwrap().unregister();
+                });
+            }
+        }
+
         if let ProjectClientState::Local { remote_id_tx, .. } = &mut self.client_state {
-            *remote_id_tx.borrow_mut() = remote_id;
+            *remote_id_tx.borrow_mut() = None;
         }
 
         self.subscriptions.clear();
-        if let Some(remote_id) = remote_id {
-            self.subscriptions
-                .push(self.client.add_model_for_remote_entity(remote_id, cx));
-        }
-        cx.emit(Event::RemoteIdChanged(remote_id))
+    }
+
+    fn register(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        self.unregister(cx);
+
+        let response = self.client.request(proto::RegisterProject {});
+        cx.spawn(|this, mut cx| async move {
+            let remote_id = response.await?.project_id;
+
+            let mut registrations = Vec::new();
+            this.update(&mut cx, |this, cx| {
+                if let ProjectClientState::Local { remote_id_tx, .. } = &mut this.client_state {
+                    *remote_id_tx.borrow_mut() = Some(remote_id);
+                }
+
+                cx.emit(Event::RemoteIdChanged(Some(remote_id)));
+
+                this.subscriptions
+                    .push(this.client.add_model_for_remote_entity(remote_id, cx));
+
+                for worktree in &this.worktrees {
+                    if let Some(worktree) = worktree.upgrade(cx) {
+                        registrations.push(worktree.update(cx, |worktree, cx| {
+                            let worktree = worktree.as_local_mut().unwrap();
+                            worktree.register(remote_id, cx)
+                        }));
+                    }
+                }
+            });
+
+            futures::future::try_join_all(registrations).await?;
+            Ok(())
+        })
     }
 
     pub fn remote_id(&self) -> Option<u64> {
@@ -725,59 +742,51 @@ impl Project {
         })
     }
 
-    pub fn unshare(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+    pub fn unshare(&mut self, cx: &mut ModelContext<Self>) {
         let rpc = self.client.clone();
-        cx.spawn(|this, mut cx| async move {
-            let project_id = this.update(&mut cx, |this, cx| {
-                if let ProjectClientState::Local {
-                    is_shared,
-                    remote_id_rx,
-                    ..
-                } = &mut this.client_state
-                {
-                    *is_shared = false;
 
-                    for open_buffer in this.opened_buffers.values_mut() {
-                        match open_buffer {
-                            OpenBuffer::Strong(buffer) => {
-                                *open_buffer = OpenBuffer::Weak(buffer.downgrade());
-                            }
-                            _ => {}
-                        }
-                    }
+        if let ProjectClientState::Local {
+            is_shared,
+            remote_id_rx,
+            ..
+        } = &mut self.client_state
+        {
+            if !*is_shared {
+                return;
+            }
 
-                    for worktree_handle in this.worktrees.iter_mut() {
-                        match worktree_handle {
-                            WorktreeHandle::Strong(worktree) => {
-                                if !worktree.read(cx).is_visible() {
-                                    *worktree_handle = WorktreeHandle::Weak(worktree.downgrade());
-                                }
-                            }
-                            _ => {}
-                        }
+            *is_shared = false;
+            self.collaborators.clear();
+            self.shared_buffers.clear();
+            for worktree_handle in self.worktrees.iter_mut() {
+                if let WorktreeHandle::Strong(worktree) = worktree_handle {
+                    let is_visible = worktree.update(cx, |worktree, _| {
+                        worktree.as_local_mut().unwrap().unshare();
+                        worktree.is_visible()
+                    });
+                    if !is_visible {
+                        *worktree_handle = WorktreeHandle::Weak(worktree.downgrade());
                     }
-
-                    remote_id_rx
-                        .borrow()
-                        .ok_or_else(|| anyhow!("no project id"))
-                } else {
-                    Err(anyhow!("can't share a remote project"))
                 }
-            })?;
+            }
 
-            rpc.send(proto::UnshareProject { project_id })?;
-            this.update(&mut cx, |this, cx| {
-                this.collaborators.clear();
-                this.shared_buffers.clear();
-                for worktree in this.worktrees(cx).collect::<Vec<_>>() {
-                    worktree.update(cx, |worktree, _| {
-                        worktree.as_local_mut().unwrap().unshare();
-                    });
+            for open_buffer in self.opened_buffers.values_mut() {
+                match open_buffer {
+                    OpenBuffer::Strong(buffer) => {
+                        *open_buffer = OpenBuffer::Weak(buffer.downgrade());
+                    }
+                    _ => {}
                 }
-                cx.notify()
-            });
-            Ok(())
-        })
+            }
+
+            if let Some(project_id) = *remote_id_rx.borrow() {
+                rpc.send(proto::UnshareProject { project_id }).log_err();
+            }
+
+            cx.notify();
+        } else {
+            log::error!("attempted to unshare a remote project");
+        }
     }
 
     fn project_unshared(&mut self, cx: &mut ModelContext<Self>) {

crates/project/src/worktree.rs 🔗

@@ -711,7 +711,9 @@ impl LocalWorktree {
                 let worktree = this.as_local_mut().unwrap();
                 match response {
                     Ok(_) => {
-                        worktree.registration = Registration::Done { project_id };
+                        if worktree.registration == Registration::Pending {
+                            worktree.registration = Registration::Done { project_id };
+                        }
                         Ok(())
                     }
                     Err(error) => {
@@ -808,6 +810,11 @@ impl LocalWorktree {
         })
     }
 
+    pub fn unregister(&mut self) {
+        self.unshare();
+        self.registration = Registration::None;
+    }
+
     pub fn unshare(&mut self) {
         self.share.take();
     }

crates/rpc/src/peer.rs 🔗

@@ -96,7 +96,7 @@ pub struct ConnectionState {
 
 const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
 const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
-const RECEIVE_TIMEOUT: Duration = Duration::from_secs(30);
+pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
 
 impl Peer {
     pub fn new() -> Arc<Self> {

crates/server/src/rpc.rs 🔗

@@ -1310,10 +1310,105 @@ mod tests {
             .unwrap();
 
         // Unshare the project as client A
+        project_a.update(cx_a, |project, cx| project.unshare(cx));
+        project_b
+            .condition(cx_b, |project, _| project.is_read_only())
+            .await;
+        assert!(worktree_a.read_with(cx_a, |tree, _| !tree.as_local().unwrap().is_shared()));
+        cx_b.update(|_| {
+            drop(project_b);
+        });
+
+        // Share the project again and ensure guests can still join.
         project_a
-            .update(cx_a, |project, cx| project.unshare(cx))
+            .update(cx_a, |project, cx| project.share(cx))
+            .await
+            .unwrap();
+        assert!(worktree_a.read_with(cx_a, |tree, _| tree.as_local().unwrap().is_shared()));
+
+        let project_b2 = Project::remote(
+            project_id,
+            client_b.clone(),
+            client_b.user_store.clone(),
+            lang_registry.clone(),
+            fs.clone(),
+            &mut cx_b.to_async(),
+        )
+        .await
+        .unwrap();
+        project_b2
+            .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx))
             .await
             .unwrap();
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_host_disconnect(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+        let lang_registry = Arc::new(LanguageRegistry::test());
+        let fs = FakeFs::new(cx_a.background());
+        cx_a.foreground().forbid_parking();
+
+        // Connect to a server as 2 clients.
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+        let client_a = server.create_client(cx_a, "user_a").await;
+        let client_b = server.create_client(cx_b, "user_b").await;
+
+        // Share a project as client A
+        fs.insert_tree(
+            "/a",
+            json!({
+                ".zed.toml": r#"collaborators = ["user_b"]"#,
+                "a.txt": "a-contents",
+                "b.txt": "b-contents",
+            }),
+        )
+        .await;
+        let project_a = cx_a.update(|cx| {
+            Project::local(
+                client_a.clone(),
+                client_a.user_store.clone(),
+                lang_registry.clone(),
+                fs.clone(),
+                cx,
+            )
+        });
+        let (worktree_a, _) = project_a
+            .update(cx_a, |p, cx| {
+                p.find_or_create_local_worktree("/a", true, cx)
+            })
+            .await
+            .unwrap();
+        worktree_a
+            .read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
+            .await;
+        let project_id = project_a.update(cx_a, |p, _| p.next_remote_id()).await;
+        let worktree_id = worktree_a.read_with(cx_a, |tree, _| tree.id());
+        project_a.update(cx_a, |p, cx| p.share(cx)).await.unwrap();
+        assert!(worktree_a.read_with(cx_a, |tree, _| tree.as_local().unwrap().is_shared()));
+
+        // Join that project as client B
+        let project_b = Project::remote(
+            project_id,
+            client_b.clone(),
+            client_b.user_store.clone(),
+            lang_registry.clone(),
+            fs.clone(),
+            &mut cx_b.to_async(),
+        )
+        .await
+        .unwrap();
+        project_b
+            .update(cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx))
+            .await
+            .unwrap();
+
+        // Drop client A's connection. Collaborators should disappear and the project should not be shown as shared.
+        server.disconnect_client(client_a.current_user_id(cx_a));
+        cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
+        project_a
+            .condition(cx_a, |project, _| project.collaborators().is_empty())
+            .await;
+        project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
         project_b
             .condition(cx_b, |project, _| project.is_read_only())
             .await;
@@ -1322,6 +1417,9 @@ mod tests {
             drop(project_b);
         });
 
+        // Await reconnection
+        let project_id = project_a.update(cx_a, |p, _| p.next_remote_id()).await;
+
         // Share the project again and ensure guests can still join.
         project_a
             .update(cx_a, |project, cx| project.share(cx))

crates/workspace/src/workspace.rs 🔗

@@ -1278,7 +1278,7 @@ impl Workspace {
         self.project.update(cx, |project, cx| {
             if project.is_local() {
                 if project.is_shared() {
-                    project.unshare(cx).detach();
+                    project.unshare(cx);
                 } else {
                     project.share(cx).detach();
                 }