Make host reconnection test pass when mutating worktree while offline

Antonio Scandurra created

Change summary

crates/call/src/room.rs                      |   2 
crates/collab/src/db.rs                      | 157 +++++++++++++---
crates/collab/src/db/project_collaborator.rs |  10 +
crates/collab/src/rpc.rs                     | 200 +++++++++++----------
crates/project/src/project.rs                |  11 +
crates/project/src/worktree.rs               | 110 +++++------
crates/rpc/proto/zed.proto                   |   2 
7 files changed, 307 insertions(+), 185 deletions(-)

Detailed changes

crates/call/src/room.rs 🔗

@@ -357,7 +357,7 @@ impl Room {
                 if let Some(project_id) = project.remote_id() {
                     projects.insert(project_id, handle.clone());
                     rejoined_projects.push(proto::RejoinProject {
-                        project_id,
+                        id: project_id,
                         worktrees: project
                             .worktrees(cx)
                             .map(|worktree| {

crates/collab/src/db.rs 🔗

@@ -1343,12 +1343,116 @@ impl Database {
 
     pub async fn rejoin_room(
         &self,
-        room_id: proto::RejoinRoom,
+        rejoin_room: proto::RejoinRoom,
         user_id: UserId,
-        connection_id: ConnectionId,
-    ) -> Result<RejoinedRoom> {
-        println!("==============");
-        todo!()
+        connection: ConnectionId,
+    ) -> Result<RoomGuard<RejoinedRoom>> {
+        self.room_transaction(|tx| async {
+            let tx = tx;
+            let room_id = RoomId::from_proto(rejoin_room.id);
+            let participant_update = room_participant::Entity::update_many()
+                .filter(
+                    Condition::all()
+                        .add(room_participant::Column::RoomId.eq(room_id))
+                        .add(room_participant::Column::UserId.eq(user_id))
+                        .add(room_participant::Column::AnsweringConnectionId.is_not_null())
+                        .add(
+                            Condition::any()
+                                .add(room_participant::Column::AnsweringConnectionLost.eq(true))
+                                .add(
+                                    room_participant::Column::AnsweringConnectionServerId
+                                        .ne(connection.owner_id as i32),
+                                ),
+                        ),
+                )
+                .set(room_participant::ActiveModel {
+                    answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
+                    answering_connection_server_id: ActiveValue::set(Some(ServerId(
+                        connection.owner_id as i32,
+                    ))),
+                    answering_connection_lost: ActiveValue::set(false),
+                    ..Default::default()
+                })
+                .exec(&*tx)
+                .await?;
+            if participant_update.rows_affected == 0 {
+                Err(anyhow!("room does not exist or was already joined"))?
+            } else {
+                let mut reshared_projects = Vec::new();
+                for reshared_project in &rejoin_room.reshared_projects {
+                    let project_id = ProjectId::from_proto(reshared_project.project_id);
+                    let project = project::Entity::find_by_id(project_id)
+                        .one(&*tx)
+                        .await?
+                        .ok_or_else(|| anyhow!("project does not exist"))?;
+                    if project.host_user_id != user_id {
+                        return Err(anyhow!("no such project"))?;
+                    }
+
+                    let mut collaborators = project
+                        .find_related(project_collaborator::Entity)
+                        .all(&*tx)
+                        .await?;
+                    let host_ix = collaborators
+                        .iter()
+                        .position(|collaborator| {
+                            collaborator.user_id == user_id && collaborator.is_host
+                        })
+                        .ok_or_else(|| anyhow!("host not found among collaborators"))?;
+                    let host = collaborators.swap_remove(host_ix);
+                    let old_connection_id = host.connection();
+
+                    project::Entity::update(project::ActiveModel {
+                        host_connection_id: ActiveValue::set(Some(connection.id as i32)),
+                        host_connection_server_id: ActiveValue::set(Some(ServerId(
+                            connection.owner_id as i32,
+                        ))),
+                        ..project.into_active_model()
+                    })
+                    .exec(&*tx)
+                    .await?;
+                    project_collaborator::Entity::update(project_collaborator::ActiveModel {
+                        connection_id: ActiveValue::set(connection.id as i32),
+                        connection_server_id: ActiveValue::set(ServerId(
+                            connection.owner_id as i32,
+                        )),
+                        ..host.into_active_model()
+                    })
+                    .exec(&*tx)
+                    .await?;
+
+                    reshared_projects.push(ResharedProject {
+                        id: project_id,
+                        old_connection_id,
+                        collaborators: collaborators
+                            .iter()
+                            .map(|collaborator| ProjectCollaborator {
+                                connection_id: collaborator.connection(),
+                                user_id: collaborator.user_id,
+                                replica_id: collaborator.replica_id,
+                                is_host: collaborator.is_host,
+                            })
+                            .collect(),
+                        worktrees: reshared_project.worktrees.clone(),
+                    });
+                }
+
+                // TODO: handle unshared projects
+                // TODO: handle left projects
+
+                let room = self.get_room(room_id, &tx).await?;
+                Ok((
+                    room_id,
+                    RejoinedRoom {
+                        room,
+                        // TODO: handle rejoined projects
+                        rejoined_projects: Default::default(),
+                        reshared_projects,
+                    },
+                ))
+            }
+        })
+        .await
     }
 
     pub async fn leave_room(
@@ -1447,10 +1551,7 @@ impl Database {
                                 host_connection_id: Default::default(),
                             });
 
-                    let collaborator_connection_id = ConnectionId {
-                        owner_id: collaborator.connection_server_id.0 as u32,
-                        id: collaborator.connection_id as u32,
-                    };
+                    let collaborator_connection_id = collaborator.connection();
                     if collaborator_connection_id != connection {
                         left_project.connection_ids.push(collaborator_connection_id);
                     }
@@ -2232,10 +2333,7 @@ impl Database {
                 collaborators: collaborators
                     .into_iter()
                     .map(|collaborator| ProjectCollaborator {
-                        connection_id: ConnectionId {
-                            owner_id: collaborator.connection_server_id.0 as u32,
-                            id: collaborator.connection_id as u32,
-                        },
+                        connection_id: collaborator.connection(),
                         user_id: collaborator.user_id,
                         replica_id: collaborator.replica_id,
                         is_host: collaborator.is_host,
@@ -2287,10 +2385,7 @@ impl Database {
                 .await?;
             let connection_ids = collaborators
                 .into_iter()
-                .map(|collaborator| ConnectionId {
-                    owner_id: collaborator.connection_server_id.0 as u32,
-                    id: collaborator.connection_id as u32,
-                })
+                .map(|collaborator| collaborator.connection())
                 .collect();
 
             let left_project = LeftProject {
@@ -2320,10 +2415,7 @@ impl Database {
                 .await?
                 .into_iter()
                 .map(|collaborator| ProjectCollaborator {
-                    connection_id: ConnectionId {
-                        owner_id: collaborator.connection_server_id.0 as u32,
-                        id: collaborator.connection_id as u32,
-                    },
+                    connection_id: collaborator.connection(),
                     user_id: collaborator.user_id,
                     replica_id: collaborator.replica_id,
                     is_host: collaborator.is_host,
@@ -2352,18 +2444,15 @@ impl Database {
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("no such project"))?;
-            let mut participants = project_collaborator::Entity::find()
+            let mut collaborators = project_collaborator::Entity::find()
                 .filter(project_collaborator::Column::ProjectId.eq(project_id))
                 .stream(&*tx)
                 .await?;
 
             let mut connection_ids = HashSet::default();
-            while let Some(participant) = participants.next().await {
-                let participant = participant?;
-                connection_ids.insert(ConnectionId {
-                    owner_id: participant.connection_server_id.0 as u32,
-                    id: participant.connection_id as u32,
-                });
+            while let Some(collaborator) = collaborators.next().await {
+                let collaborator = collaborator?;
+                connection_ids.insert(collaborator.connection());
             }
 
             if connection_ids.contains(&connection_id) {
@@ -2380,7 +2469,7 @@ impl Database {
         project_id: ProjectId,
         tx: &DatabaseTransaction,
     ) -> Result<Vec<ConnectionId>> {
-        let mut participants = project_collaborator::Entity::find()
+        let mut collaborators = project_collaborator::Entity::find()
             .filter(
                 project_collaborator::Column::ProjectId
                     .eq(project_id)
@@ -2390,12 +2479,9 @@ impl Database {
             .await?;
 
         let mut guest_connection_ids = Vec::new();
-        while let Some(participant) = participants.next().await {
-            let participant = participant?;
-            guest_connection_ids.push(ConnectionId {
-                owner_id: participant.connection_server_id.0 as u32,
-                id: participant.connection_id as u32,
-            });
+        while let Some(collaborator) = collaborators.next().await {
+            let collaborator = collaborator?;
+            guest_connection_ids.push(collaborator.connection());
         }
         Ok(guest_connection_ids)
     }
@@ -2817,6 +2903,7 @@ pub struct ResharedProject {
     pub id: ProjectId,
     pub old_connection_id: ConnectionId,
     pub collaborators: Vec<ProjectCollaborator>,
+    pub worktrees: Vec<proto::WorktreeMetadata>,
 }
 
 pub struct RejoinedProject {

crates/collab/src/db/project_collaborator.rs 🔗

@@ -1,4 +1,5 @@
 use super::{ProjectCollaboratorId, ProjectId, ReplicaId, ServerId, UserId};
+use rpc::ConnectionId;
 use sea_orm::entity::prelude::*;
 
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
@@ -14,6 +15,15 @@ pub struct Model {
     pub is_host: bool,
 }
 
+impl Model {
+    pub fn connection(&self) -> ConnectionId {
+        ConnectionId {
+            owner_id: self.connection_server_id.0 as u32,
+            id: self.connection_id as u32,
+        }
+    }
+}
+
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
 pub enum Relation {
     #[sea_orm(

crates/collab/src/rpc.rs 🔗

@@ -942,56 +942,55 @@ async fn rejoin_room(
     response: Response<proto::RejoinRoom>,
     session: Session,
 ) -> Result<()> {
-    let mut rejoined_room = session
-        .db()
-        .await
-        .rejoin_room(request, session.user_id, session.connection_id)
-        .await?;
+    {
+        let mut rejoined_room = session
+            .db()
+            .await
+            .rejoin_room(request, session.user_id, session.connection_id)
+            .await?;
 
-    response.send(proto::RejoinRoomResponse {
-        room: Some(rejoined_room.room.clone()),
-        reshared_projects: rejoined_room
-            .reshared_projects
-            .iter()
-            .map(|project| proto::ResharedProject {
-                id: project.id.to_proto(),
-                collaborators: project
-                    .collaborators
-                    .iter()
-                    .map(|collaborator| collaborator.to_proto())
-                    .collect(),
-            })
-            .collect(),
-        rejoined_projects: rejoined_room
-            .rejoined_projects
-            .iter()
-            .map(|rejoined_project| proto::RejoinedProject {
-                id: rejoined_project.id.to_proto(),
-                worktrees: rejoined_project
-                    .worktrees
-                    .iter()
-                    .map(|worktree| proto::WorktreeMetadata {
-                        id: worktree.id,
-                        root_name: worktree.root_name.clone(),
-                        visible: worktree.visible,
-                        abs_path: worktree.abs_path.clone(),
-                    })
-                    .collect(),
-                collaborators: rejoined_project
-                    .collaborators
-                    .iter()
-                    .map(|collaborator| collaborator.to_proto())
-                    .collect(),
-                language_servers: rejoined_project.language_servers.clone(),
-            })
-            .collect(),
-    })?;
-    room_updated(&rejoined_room.room, &session.peer);
+        response.send(proto::RejoinRoomResponse {
+            room: Some(rejoined_room.room.clone()),
+            reshared_projects: rejoined_room
+                .reshared_projects
+                .iter()
+                .map(|project| proto::ResharedProject {
+                    id: project.id.to_proto(),
+                    collaborators: project
+                        .collaborators
+                        .iter()
+                        .map(|collaborator| collaborator.to_proto())
+                        .collect(),
+                })
+                .collect(),
+            rejoined_projects: rejoined_room
+                .rejoined_projects
+                .iter()
+                .map(|rejoined_project| proto::RejoinedProject {
+                    id: rejoined_project.id.to_proto(),
+                    worktrees: rejoined_project
+                        .worktrees
+                        .iter()
+                        .map(|worktree| proto::WorktreeMetadata {
+                            id: worktree.id,
+                            root_name: worktree.root_name.clone(),
+                            visible: worktree.visible,
+                            abs_path: worktree.abs_path.clone(),
+                        })
+                        .collect(),
+                    collaborators: rejoined_project
+                        .collaborators
+                        .iter()
+                        .map(|collaborator| collaborator.to_proto())
+                        .collect(),
+                    language_servers: rejoined_project.language_servers.clone(),
+                })
+                .collect(),
+        })?;
+        room_updated(&rejoined_room.room, &session.peer);
 
-    // Notify other participants about this peer's reconnection to projects.
-    for project in &rejoined_room.reshared_projects {
-        for collaborator in &project.collaborators {
-            if collaborator.connection_id != session.connection_id {
+        for project in &rejoined_room.reshared_projects {
+            for collaborator in &project.collaborators {
                 session
                     .peer
                     .send(
@@ -1004,11 +1003,28 @@ async fn rejoin_room(
                     )
                     .trace_err();
             }
+
+            broadcast(
+                session.connection_id,
+                project
+                    .collaborators
+                    .iter()
+                    .map(|collaborator| collaborator.connection_id),
+                |connection_id| {
+                    session.peer.forward_send(
+                        session.connection_id,
+                        connection_id,
+                        proto::UpdateProject {
+                            project_id: project.id.to_proto(),
+                            worktrees: project.worktrees.clone(),
+                        },
+                    )
+                },
+            );
         }
-    }
-    for project in &rejoined_room.rejoined_projects {
-        for collaborator in &project.collaborators {
-            if collaborator.connection_id != session.connection_id {
+
+        for project in &rejoined_room.rejoined_projects {
+            for collaborator in &project.collaborators {
                 session
                     .peer
                     .send(
@@ -1022,57 +1038,57 @@ async fn rejoin_room(
                     .trace_err();
             }
         }
-    }
 
-    for project in &mut rejoined_room.rejoined_projects {
-        for worktree in mem::take(&mut project.worktrees) {
-            #[cfg(any(test, feature = "test-support"))]
-            const MAX_CHUNK_SIZE: usize = 2;
-            #[cfg(not(any(test, feature = "test-support")))]
-            const MAX_CHUNK_SIZE: usize = 256;
+        for project in &mut rejoined_room.rejoined_projects {
+            for worktree in mem::take(&mut project.worktrees) {
+                #[cfg(any(test, feature = "test-support"))]
+                const MAX_CHUNK_SIZE: usize = 2;
+                #[cfg(not(any(test, feature = "test-support")))]
+                const MAX_CHUNK_SIZE: usize = 256;
 
-            // Stream this worktree's entries.
-            let message = proto::UpdateWorktree {
-                project_id: project.id.to_proto(),
-                worktree_id: worktree.id,
-                abs_path: worktree.abs_path.clone(),
-                root_name: worktree.root_name,
-                updated_entries: worktree.updated_entries,
-                removed_entries: worktree.removed_entries,
-                scan_id: worktree.scan_id,
-                is_last_update: worktree.is_complete,
-            };
-            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
-                session.peer.send(session.connection_id, update.clone())?;
+                // Stream this worktree's entries.
+                let message = proto::UpdateWorktree {
+                    project_id: project.id.to_proto(),
+                    worktree_id: worktree.id,
+                    abs_path: worktree.abs_path.clone(),
+                    root_name: worktree.root_name,
+                    updated_entries: worktree.updated_entries,
+                    removed_entries: worktree.removed_entries,
+                    scan_id: worktree.scan_id,
+                    is_last_update: worktree.is_complete,
+                };
+                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
+                    session.peer.send(session.connection_id, update.clone())?;
+                }
+
+                // Stream this worktree's diagnostics.
+                for summary in worktree.diagnostic_summaries {
+                    session.peer.send(
+                        session.connection_id,
+                        proto::UpdateDiagnosticSummary {
+                            project_id: project.id.to_proto(),
+                            worktree_id: worktree.id,
+                            summary: Some(summary),
+                        },
+                    )?;
+                }
             }
 
-            // Stream this worktree's diagnostics.
-            for summary in worktree.diagnostic_summaries {
+            for language_server in &project.language_servers {
                 session.peer.send(
                     session.connection_id,
-                    proto::UpdateDiagnosticSummary {
+                    proto::UpdateLanguageServer {
                         project_id: project.id.to_proto(),
-                        worktree_id: worktree.id,
-                        summary: Some(summary),
+                        language_server_id: language_server.id,
+                        variant: Some(
+                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
+                                proto::LspDiskBasedDiagnosticsUpdated {},
+                            ),
+                        ),
                     },
                 )?;
             }
         }
-
-        for language_server in &project.language_servers {
-            session.peer.send(
-                session.connection_id,
-                proto::UpdateLanguageServer {
-                    project_id: project.id.to_proto(),
-                    language_server_id: language_server.id,
-                    variant: Some(
-                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
-                            proto::LspDiskBasedDiagnosticsUpdated {},
-                        ),
-                    ),
-                },
-            )?;
-        }
     }
 
     update_user_contacts(session.user_id, &session).await?;

crates/project/src/project.rs 🔗

@@ -1081,6 +1081,17 @@ impl Project {
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         self.set_collaborators_from_proto(message.collaborators, cx)?;
+        for worktree in self.worktrees.iter() {
+            if let Some(worktree) = worktree.upgrade(&cx) {
+                worktree.update(cx, |worktree, _| {
+                    if let Some(worktree) = worktree.as_local_mut() {
+                        worktree.reshare()
+                    } else {
+                        Ok(())
+                    }
+                })?;
+            }
+        }
         Ok(())
     }
 

crates/project/src/worktree.rs 🔗

@@ -166,6 +166,7 @@ enum ScanState {
 struct ShareState {
     project_id: u64,
     snapshots_tx: watch::Sender<LocalSnapshot>,
+    reshared: watch::Sender<()>,
     _maintain_remote_snapshot: Task<Option<()>>,
 }
 
@@ -967,9 +968,11 @@ impl LocalWorktree {
         let (share_tx, share_rx) = oneshot::channel();
 
         if self.share.is_some() {
-            let _ = share_tx.send(Ok(()));
+            let _ = share_tx.send(());
         } else {
             let (snapshots_tx, mut snapshots_rx) = watch::channel_with(self.snapshot());
+            let (reshared_tx, mut reshared_rx) = watch::channel();
+            let _ = reshared_rx.try_recv();
             let worktree_id = cx.model_id() as u64;
 
             for (path, summary) in self.diagnostic_summaries.iter() {
@@ -982,47 +985,48 @@ impl LocalWorktree {
                 }
             }
 
-            let maintain_remote_snapshot = cx.background().spawn({
-                let rpc = self.client.clone();
+            let _maintain_remote_snapshot = cx.background().spawn({
+                let client = self.client.clone();
                 async move {
-                    let mut prev_snapshot = match snapshots_rx.recv().await {
-                        Some(snapshot) => {
-                            let update = proto::UpdateWorktree {
-                                project_id,
-                                worktree_id,
-                                abs_path: snapshot.abs_path().to_string_lossy().into(),
-                                root_name: snapshot.root_name().to_string(),
-                                updated_entries: snapshot
-                                    .entries_by_path
-                                    .iter()
-                                    .map(Into::into)
-                                    .collect(),
-                                removed_entries: Default::default(),
-                                scan_id: snapshot.scan_id as u64,
-                                is_last_update: true,
-                            };
-                            if let Err(error) = send_worktree_update(&rpc, update).await {
-                                let _ = share_tx.send(Err(error));
-                                return Err(anyhow!("failed to send initial update worktree"));
-                            } else {
-                                let _ = share_tx.send(Ok(()));
-                                snapshot
+                    let mut share_tx = Some(share_tx);
+                    let mut prev_snapshot = LocalSnapshot {
+                        ignores_by_parent_abs_path: Default::default(),
+                        git_repositories: Default::default(),
+                        removed_entry_ids: Default::default(),
+                        next_entry_id: Default::default(),
+                        snapshot: Snapshot {
+                            id: WorktreeId(worktree_id as usize),
+                            abs_path: Path::new("").into(),
+                            root_name: Default::default(),
+                            root_char_bag: Default::default(),
+                            entries_by_path: Default::default(),
+                            entries_by_id: Default::default(),
+                            scan_id: 0,
+                            is_complete: true,
+                        },
+                    };
+                    while let Some(snapshot) = snapshots_rx.recv().await {
+                        #[cfg(any(test, feature = "test-support"))]
+                        const MAX_CHUNK_SIZE: usize = 2;
+                        #[cfg(not(any(test, feature = "test-support")))]
+                        const MAX_CHUNK_SIZE: usize = 256;
+
+                        let update =
+                            snapshot.build_update(&prev_snapshot, project_id, worktree_id, true);
+                        for update in proto::split_worktree_update(update, MAX_CHUNK_SIZE) {
+                            while let Err(error) = client.request(update.clone()).await {
+                                log::error!("failed to send worktree update: {}", error);
+                                log::info!("waiting for worktree to be reshared");
+                                if reshared_rx.next().await.is_none() {
+                                    return Ok(());
+                                }
                             }
                         }
-                        None => {
-                            share_tx
-                                .send(Err(anyhow!("worktree dropped before share completed")))
-                                .ok();
-                            return Err(anyhow!("failed to send initial update worktree"));
+
+                        if let Some(share_tx) = share_tx.take() {
+                            let _ = share_tx.send(());
                         }
-                    };
 
-                    while let Some(snapshot) = snapshots_rx.recv().await {
-                        send_worktree_update(
-                            &rpc,
-                            snapshot.build_update(&prev_snapshot, project_id, worktree_id, true),
-                        )
-                        .await?;
                         prev_snapshot = snapshot;
                     }
 
@@ -1034,21 +1038,28 @@ impl LocalWorktree {
             self.share = Some(ShareState {
                 project_id,
                 snapshots_tx,
-                _maintain_remote_snapshot: maintain_remote_snapshot,
+                reshared: reshared_tx,
+                _maintain_remote_snapshot,
             });
         }
 
-        cx.foreground().spawn(async move {
-            share_rx
-                .await
-                .unwrap_or_else(|_| Err(anyhow!("share ended")))
-        })
+        cx.foreground()
+            .spawn(async move { share_rx.await.map_err(|_| anyhow!("share ended")) })
     }
 
     pub fn unshare(&mut self) {
         self.share.take();
     }
 
+    pub fn reshare(&mut self) -> Result<()> {
+        let share = self
+            .share
+            .as_mut()
+            .ok_or_else(|| anyhow!("can't reshare a worktree that wasn't shared"))?;
+        *share.reshared.borrow_mut() = ();
+        Ok(())
+    }
+
     pub fn is_shared(&self) -> bool {
         self.share.is_some()
     }
@@ -2936,19 +2947,6 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry {
     }
 }
 
-async fn send_worktree_update(client: &Arc<Client>, update: proto::UpdateWorktree) -> Result<()> {
-    #[cfg(any(test, feature = "test-support"))]
-    const MAX_CHUNK_SIZE: usize = 2;
-    #[cfg(not(any(test, feature = "test-support")))]
-    const MAX_CHUNK_SIZE: usize = 256;
-
-    for update in proto::split_worktree_update(update, MAX_CHUNK_SIZE) {
-        client.request(update).await?;
-    }
-
-    Ok(())
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

crates/rpc/proto/zed.proto 🔗

@@ -172,7 +172,7 @@ message RejoinRoom {
 }
 
 message RejoinProject {
-    uint64 project_id = 1;
+    uint64 id = 1;
     repeated RejoinWorktree worktrees = 2;
 }