Merge pull request #1950 from zed-industries/reconnect-to-room

Antonio Scandurra created

Automatically re-join call when client connection drops

Change summary

crates/call/src/participant.rs                                                       |  10 
crates/call/src/room.rs                                                              | 116 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql                       |  19 
crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql |   7 
crates/collab/src/db.rs                                                              |  95 
crates/collab/src/db/room_participant.rs                                             |   1 
crates/collab/src/executor.rs                                                        |  36 
crates/collab/src/integration_tests.rs                                               | 124 
crates/collab/src/lib.rs                                                             |   1 
crates/collab/src/rpc.rs                                                             | 194 
10 files changed, 419 insertions(+), 184 deletions(-)

Detailed changes

crates/call/src/participant.rs 🔗

@@ -4,7 +4,7 @@ use collections::HashMap;
 use gpui::WeakModelHandle;
 pub use live_kit_client::Frame;
 use project::Project;
-use std::sync::Arc;
+use std::{fmt, sync::Arc};
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum ParticipantLocation {
@@ -36,7 +36,7 @@ pub struct LocalParticipant {
     pub active_project: Option<WeakModelHandle<Project>>,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct RemoteParticipant {
     pub user: Arc<User>,
     pub projects: Vec<proto::ParticipantProject>,
@@ -49,6 +49,12 @@ pub struct RemoteVideoTrack {
     pub(crate) live_kit_track: Arc<live_kit_client::RemoteVideoTrack>,
 }
 
+impl fmt::Debug for RemoteVideoTrack {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("RemoteVideoTrack").finish()
+    }
+}
+
 impl RemoteVideoTrack {
     pub fn frames(&self) -> async_broadcast::Receiver<Frame> {
         self.live_kit_track.frames()

crates/call/src/room.rs 🔗

@@ -5,14 +5,18 @@ use crate::{
 use anyhow::{anyhow, Result};
 use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore};
 use collections::{BTreeMap, HashSet};
-use futures::StreamExt;
-use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
+use futures::{FutureExt, StreamExt};
+use gpui::{
+    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
+};
 use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate};
 use postage::stream::Stream;
 use project::Project;
-use std::{mem, sync::Arc};
+use std::{mem, sync::Arc, time::Duration};
 use util::{post_inc, ResultExt};
 
+pub const RECONNECTION_TIMEOUT: Duration = client::RECEIVE_TIMEOUT;
+
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum Event {
     ParticipantLocationChanged {
@@ -46,6 +50,7 @@ pub struct Room {
     user_store: ModelHandle<UserStore>,
     subscriptions: Vec<client::Subscription>,
     pending_room_update: Option<Task<()>>,
+    _maintain_connection: Task<Result<()>>,
 }
 
 impl Entity for Room {
@@ -66,21 +71,6 @@ impl Room {
         user_store: ModelHandle<UserStore>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let mut client_status = client.status();
-        cx.spawn_weak(|this, mut cx| async move {
-            let is_connected = client_status
-                .next()
-                .await
-                .map_or(false, |s| s.is_connected());
-            // Even if we're initially connected, any future change of the status means we momentarily disconnected.
-            if !is_connected || client_status.next().await.is_some() {
-                if let Some(this) = this.upgrade(&cx) {
-                    let _ = this.update(&mut cx, |this, cx| this.leave(cx));
-                }
-            }
-        })
-        .detach();
-
         let live_kit_room = if let Some(connection_info) = live_kit_connection_info {
             let room = live_kit_client::Room::new();
             let mut status = room.status();
@@ -131,6 +121,9 @@ impl Room {
             None
         };
 
+        let _maintain_connection =
+            cx.spawn_weak(|this, cx| Self::maintain_connection(this, client.clone(), cx));
+
         Self {
             id,
             live_kit: live_kit_room,
@@ -145,6 +138,7 @@ impl Room {
             pending_room_update: None,
             client,
             user_store,
+            _maintain_connection,
         }
     }
 
@@ -245,6 +239,83 @@ impl Room {
         Ok(())
     }
 
+    async fn maintain_connection(
+        this: WeakModelHandle<Self>,
+        client: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        let mut client_status = client.status();
+        loop {
+            let is_connected = client_status
+                .next()
+                .await
+                .map_or(false, |s| s.is_connected());
+            // Even if we're initially connected, any future change of the status means we momentarily disconnected.
+            if !is_connected || client_status.next().await.is_some() {
+                let room_id = this
+                    .upgrade(&cx)
+                    .ok_or_else(|| anyhow!("room was dropped"))?
+                    .update(&mut cx, |this, cx| {
+                        this.status = RoomStatus::Rejoining;
+                        cx.notify();
+                        this.id
+                    });
+
+                // Wait for client to re-establish a connection to the server.
+                let mut reconnection_timeout = cx.background().timer(RECONNECTION_TIMEOUT).fuse();
+                let client_reconnection = async {
+                    loop {
+                        if let Some(status) = client_status.next().await {
+                            if status.is_connected() {
+                                return true;
+                            }
+                        } else {
+                            return false;
+                        }
+                    }
+                }
+                .fuse();
+                futures::pin_mut!(client_reconnection);
+
+                futures::select_biased! {
+                    reconnected = client_reconnection => {
+                        if reconnected {
+                            // Client managed to reconnect to the server. Now attempt to join the room.
+                            let rejoin_room = async {
+                                let response = client.request(proto::JoinRoom { id: room_id }).await?;
+                                let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
+                                this.upgrade(&cx)
+                                    .ok_or_else(|| anyhow!("room was dropped"))?
+                                    .update(&mut cx, |this, cx| {
+                                        this.status = RoomStatus::Online;
+                                        this.apply_room_update(room_proto, cx)
+                                    })?;
+                                anyhow::Ok(())
+                            };
+
+                            // If we successfully joined the room, go back around the loop
+                            // waiting for future connection status changes.
+                            if rejoin_room.await.log_err().is_some() {
+                                continue;
+                            }
+                        }
+                    }
+                    _ = reconnection_timeout => {}
+                }
+
+                // The client failed to re-establish a connection to the server
+                // or an error occurred while trying to re-join the room. Either way
+                // we leave the room and return an error.
+                if let Some(this) = this.upgrade(&cx) {
+                    let _ = this.update(&mut cx, |this, cx| this.leave(cx));
+                }
+                return Err(anyhow!(
+                    "can't reconnect to room: client failed to re-establish connection"
+                ));
+            }
+        }
+    }
+
     pub fn id(&self) -> u64 {
         self.id
     }
@@ -325,9 +396,11 @@ impl Room {
                 }
 
                 if let Some(participants) = remote_participants.log_err() {
+                    let mut participant_peer_ids = HashSet::default();
                     for (participant, user) in room.participants.into_iter().zip(participants) {
                         let peer_id = PeerId(participant.peer_id);
                         this.participant_user_ids.insert(participant.user_id);
+                        participant_peer_ids.insert(peer_id);
 
                         let old_projects = this
                             .remote_participants
@@ -394,8 +467,8 @@ impl Room {
                         }
                     }
 
-                    this.remote_participants.retain(|_, participant| {
-                        if this.participant_user_ids.contains(&participant.user.id) {
+                    this.remote_participants.retain(|peer_id, participant| {
+                        if participant_peer_ids.contains(peer_id) {
                             true
                         } else {
                             for project in &participant.projects {
@@ -477,10 +550,12 @@ impl Room {
         {
             for participant in self.remote_participants.values() {
                 assert!(self.participant_user_ids.contains(&participant.user.id));
+                assert_ne!(participant.user.id, self.client.user_id().unwrap());
             }
 
             for participant in &self.pending_participants {
                 assert!(self.participant_user_ids.contains(&participant.id));
+                assert_ne!(participant.id, self.client.user_id().unwrap());
             }
 
             assert_eq!(
@@ -751,6 +826,7 @@ impl Default for ScreenTrack {
 #[derive(Copy, Clone, PartialEq, Eq)]
 pub enum RoomStatus {
     Online,
+    Rejoining,
     Offline,
 }
 

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -1,5 +1,5 @@
 CREATE TABLE "users" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "github_login" VARCHAR,
     "admin" BOOLEAN,
     "email_address" VARCHAR(255) DEFAULT NULL,
@@ -17,14 +17,14 @@ CREATE INDEX "index_users_on_email_address" ON "users" ("email_address");
 CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id");
 
 CREATE TABLE "access_tokens" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "user_id" INTEGER REFERENCES users (id),
     "hash" VARCHAR(128)
 );
 CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id");
 
 CREATE TABLE "contacts" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "user_id_a" INTEGER REFERENCES users (id) NOT NULL,
     "user_id_b" INTEGER REFERENCES users (id) NOT NULL,
     "a_to_b" BOOLEAN NOT NULL,
@@ -35,12 +35,12 @@ CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_
 CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");
 
 CREATE TABLE "rooms" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "live_kit_room" VARCHAR NOT NULL
 );
 
 CREATE TABLE "projects" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "room_id" INTEGER REFERENCES rooms (id) NOT NULL,
     "host_user_id" INTEGER REFERENCES users (id) NOT NULL,
     "host_connection_id" INTEGER NOT NULL,
@@ -99,7 +99,7 @@ CREATE TABLE "language_servers" (
 CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id");
 
 CREATE TABLE "project_collaborators" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
     "connection_id" INTEGER NOT NULL,
     "connection_epoch" TEXT NOT NULL,
@@ -110,13 +110,16 @@ CREATE TABLE "project_collaborators" (
 CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id");
 CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id");
 CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch");
+CREATE INDEX "index_project_collaborators_on_connection_id" ON "project_collaborators" ("connection_id");
+CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and_epoch" ON "project_collaborators" ("project_id", "connection_id", "connection_epoch");
 
 CREATE TABLE "room_participants" (
-    "id" INTEGER PRIMARY KEY,
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "room_id" INTEGER NOT NULL REFERENCES rooms (id),
     "user_id" INTEGER NOT NULL REFERENCES users (id),
     "answering_connection_id" INTEGER,
     "answering_connection_epoch" TEXT,
+    "answering_connection_lost" BOOLEAN NOT NULL,
     "location_kind" INTEGER,
     "location_project_id" INTEGER,
     "initial_project_id" INTEGER,
@@ -127,3 +130,5 @@ CREATE TABLE "room_participants" (
 CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id");
 CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch");
 CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch");
+CREATE INDEX "index_room_participants_on_answering_connection_id" ON "room_participants" ("answering_connection_id");
+CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_epoch" ON "room_participants" ("answering_connection_id", "answering_connection_epoch");

crates/collab/migrations/20221207165001_add_connection_lost_to_room_participants.sql 🔗

@@ -0,0 +1,7 @@
+ALTER TABLE "room_participants"
+    ADD "answering_connection_lost" BOOLEAN NOT NULL DEFAULT FALSE;
+
+CREATE INDEX "index_project_collaborators_on_connection_id" ON "project_collaborators" ("connection_id");
+CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and_epoch" ON "project_collaborators" ("project_id", "connection_id", "connection_epoch");
+CREATE INDEX "index_room_participants_on_answering_connection_id" ON "room_participants" ("answering_connection_id");
+CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_epoch" ON "room_participants" ("answering_connection_id", "answering_connection_epoch");

crates/collab/src/db.rs 🔗

@@ -1034,6 +1034,7 @@ impl Database {
                 user_id: ActiveValue::set(user_id),
                 answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
                 answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                answering_connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(user_id),
                 calling_connection_id: ActiveValue::set(connection_id.0 as i32),
                 calling_connection_epoch: ActiveValue::set(self.epoch),
@@ -1060,6 +1061,7 @@ impl Database {
             room_participant::ActiveModel {
                 room_id: ActiveValue::set(room_id),
                 user_id: ActiveValue::set(called_user_id),
+                answering_connection_lost: ActiveValue::set(false),
                 calling_user_id: ActiveValue::set(calling_user_id),
                 calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32),
                 calling_connection_epoch: ActiveValue::set(self.epoch),
@@ -1175,11 +1177,16 @@ impl Database {
                     room_participant::Column::RoomId
                         .eq(room_id)
                         .and(room_participant::Column::UserId.eq(user_id))
-                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
+                        .and(
+                            room_participant::Column::AnsweringConnectionId
+                                .is_null()
+                                .or(room_participant::Column::AnsweringConnectionLost.eq(true)),
+                        ),
                 )
                 .set(room_participant::ActiveModel {
                     answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
                     answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
+                    answering_connection_lost: ActiveValue::set(false),
                     ..Default::default()
                 })
                 .exec(&*tx)
@@ -1197,7 +1204,7 @@ impl Database {
     pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<RoomGuard<LeftRoom>> {
         self.room_transaction(|tx| async move {
             let leaving_participant = room_participant::Entity::find()
-                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32))
                 .one(&*tx)
                 .await?;
 
@@ -1240,7 +1247,7 @@ impl Database {
                         project_collaborator::Column::ProjectId,
                         QueryProjectIds::ProjectId,
                     )
-                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
+                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32))
                     .into_values::<_, QueryProjectIds>()
                     .all(&*tx)
                     .await?;
@@ -1277,7 +1284,7 @@ impl Database {
 
                 // Leave projects.
                 project_collaborator::Entity::delete_many()
-                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
+                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32))
                     .exec(&*tx)
                     .await?;
 
@@ -1286,7 +1293,7 @@ impl Database {
                     .filter(
                         project::Column::RoomId
                             .eq(room_id)
-                            .and(project::Column::HostConnectionId.eq(connection_id.0)),
+                            .and(project::Column::HostConnectionId.eq(connection_id.0 as i32)),
                     )
                     .exec(&*tx)
                     .await?;
@@ -1344,11 +1351,9 @@ impl Database {
             }
 
             let result = room_participant::Entity::update_many()
-                .filter(
-                    room_participant::Column::RoomId
-                        .eq(room_id)
-                        .and(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)),
-                )
+                .filter(room_participant::Column::RoomId.eq(room_id).and(
+                    room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32),
+                ))
                 .set(room_participant::ActiveModel {
                     location_kind: ActiveValue::set(Some(location_kind)),
                     location_project_id: ActiveValue::set(location_project_id),
@@ -1367,6 +1372,66 @@ impl Database {
         .await
     }
 
+    pub async fn connection_lost(
+        &self,
+        connection_id: ConnectionId,
+    ) -> Result<RoomGuard<Vec<LeftProject>>> {
+        self.room_transaction(|tx| async move {
+            let participant = room_participant::Entity::find()
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("not a participant in any room"))?;
+            let room_id = participant.room_id;
+
+            room_participant::Entity::update(room_participant::ActiveModel {
+                answering_connection_lost: ActiveValue::set(true),
+                ..participant.into_active_model()
+            })
+            .exec(&*tx)
+            .await?;
+
+            let collaborator_on_projects = project_collaborator::Entity::find()
+                .find_also_related(project::Entity)
+                .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32))
+                .all(&*tx)
+                .await?;
+            project_collaborator::Entity::delete_many()
+                .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32))
+                .exec(&*tx)
+                .await?;
+
+            let mut left_projects = Vec::new();
+            for (_, project) in collaborator_on_projects {
+                if let Some(project) = project {
+                    let collaborators = project
+                        .find_related(project_collaborator::Entity)
+                        .all(&*tx)
+                        .await?;
+                    let connection_ids = collaborators
+                        .into_iter()
+                        .map(|collaborator| ConnectionId(collaborator.connection_id as u32))
+                        .collect();
+
+                    left_projects.push(LeftProject {
+                        id: project.id,
+                        host_user_id: project.host_user_id,
+                        host_connection_id: ConnectionId(project.host_connection_id as u32),
+                        connection_ids,
+                    });
+                }
+            }
+
+            project::Entity::delete_many()
+                .filter(project::Column::HostConnectionId.eq(connection_id.0 as i32))
+                .exec(&*tx)
+                .await?;
+
+            Ok((room_id, left_projects))
+        })
+        .await
+    }
+
     fn build_incoming_call(
         room: &proto::Room,
         called_user_id: UserId,
@@ -1514,7 +1579,7 @@ impl Database {
     ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
         self.room_transaction(|tx| async move {
             let participant = room_participant::Entity::find()
-                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32))
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("could not find participant"))?;
@@ -1600,7 +1665,7 @@ impl Database {
     ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
         self.room_transaction(|tx| async move {
             let project = project::Entity::find_by_id(project_id)
-                .filter(project::Column::HostConnectionId.eq(connection_id.0))
+                .filter(project::Column::HostConnectionId.eq(connection_id.0 as i32))
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("no such project"))?;
@@ -1654,7 +1719,7 @@ impl Database {
 
             // Ensure the update comes from the host.
             let project = project::Entity::find_by_id(project_id)
-                .filter(project::Column::HostConnectionId.eq(connection_id.0))
+                .filter(project::Column::HostConnectionId.eq(connection_id.0 as i32))
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("no such project"))?;
@@ -1837,7 +1902,7 @@ impl Database {
     ) -> Result<RoomGuard<(Project, ReplicaId)>> {
         self.room_transaction(|tx| async move {
             let participant = room_participant::Entity::find()
-                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0 as i32))
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("must join a room first"))?;
@@ -1974,7 +2039,7 @@ impl Database {
                 .filter(
                     project_collaborator::Column::ProjectId
                         .eq(project_id)
-                        .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)),
+                        .and(project_collaborator::Column::ConnectionId.eq(connection_id.0 as i32)),
                 )
                 .exec(&*tx)
                 .await?;

crates/collab/src/db/room_participant.rs 🔗

@@ -10,6 +10,7 @@ pub struct Model {
     pub user_id: UserId,
     pub answering_connection_id: Option<i32>,
     pub answering_connection_epoch: Option<Uuid>,
+    pub answering_connection_lost: bool,
     pub location_kind: Option<i32>,
     pub location_project_id: Option<ProjectId>,
     pub initial_project_id: Option<ProjectId>,

crates/collab/src/executor.rs 🔗

@@ -0,0 +1,36 @@
+use std::{future::Future, time::Duration};
+
+#[derive(Clone)]
+pub enum Executor {
+    Production,
+    #[cfg(test)]
+    Deterministic(std::sync::Arc<gpui::executor::Background>),
+}
+
+impl Executor {
+    pub fn spawn_detached<F>(&self, future: F)
+    where
+        F: 'static + Send + Future<Output = ()>,
+    {
+        match self {
+            Executor::Production => {
+                tokio::spawn(future);
+            }
+            #[cfg(test)]
+            Executor::Deterministic(background) => {
+                background.spawn(future).detach();
+            }
+        }
+    }
+
+    pub fn sleep(&self, duration: Duration) -> impl Future<Output = ()> {
+        let this = self.clone();
+        async move {
+            match this {
+                Executor::Production => tokio::time::sleep(duration).await,
+                #[cfg(test)]
+                Executor::Deterministic(background) => background.timer(duration).await,
+            }
+        }
+    }
+}

crates/collab/src/integration_tests.rs 🔗

@@ -1,9 +1,9 @@
 use crate::{
     db::{self, NewUserParams, TestDb, UserId},
-    rpc::{Executor, Server},
+    executor::Executor,
+    rpc::{Server, RECONNECT_TIMEOUT},
     AppState,
 };
-
 use ::rpc::Peer;
 use anyhow::anyhow;
 use call::{room, ActiveCall, ParticipantLocation, Room};
@@ -17,7 +17,7 @@ use editor::{
     ToggleCodeActions, Undo,
 };
 use fs::{FakeFs, Fs as _, HomeDir, LineEnding};
-use futures::{channel::oneshot, Future, StreamExt as _};
+use futures::{channel::oneshot, StreamExt as _};
 use gpui::{
     executor::{self, Deterministic},
     geometry::vector::vec2f,
@@ -45,7 +45,6 @@ use std::{
         atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
         Arc,
     },
-    time::Duration,
 };
 use theme::ThemeRegistry;
 use unindent::Unindent as _;
@@ -366,7 +365,7 @@ async fn test_room_uniqueness(
 }
 
 #[gpui::test(iterations = 10)]
-async fn test_leaving_room_on_disconnection(
+async fn test_disconnecting_from_room(
     deterministic: Arc<Deterministic>,
     cx_a: &mut TestAppContext,
     cx_b: &mut TestAppContext,
@@ -415,9 +414,29 @@ async fn test_leaving_room_on_disconnection(
         }
     );
 
+    // User A automatically reconnects to the room upon disconnection.
+    server.disconnect_client(client_a.peer_id().unwrap());
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
+    deterministic.run_until_parked();
+    assert_eq!(
+        room_participants(&room_a, cx_a),
+        RoomParticipants {
+            remote: vec!["user_b".to_string()],
+            pending: Default::default()
+        }
+    );
+    assert_eq!(
+        room_participants(&room_b, cx_b),
+        RoomParticipants {
+            remote: vec!["user_a".to_string()],
+            pending: Default::default()
+        }
+    );
+
     // When user A disconnects, both client A and B clear their room on the active call.
+    server.forbid_connections();
     server.disconnect_client(client_a.peer_id().unwrap());
-    cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none()));
     active_call_b.read_with(cx_b, |call, _| assert!(call.room().is_none()));
     assert_eq!(
@@ -435,6 +454,10 @@ async fn test_leaving_room_on_disconnection(
         }
     );
 
+    // Allow user A to reconnect to the server.
+    server.allow_connections();
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
+
     // Call user B again from client A.
     active_call_a
         .update(cx_a, |call, cx| {
@@ -558,7 +581,7 @@ async fn test_calls_on_multiple_connections(
 
     // User B disconnects the client that is not on the call. Everything should be fine.
     client_b1.disconnect(&cx_b1.to_async()).unwrap();
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
     client_b1
         .authenticate_and_connect(false, &cx_b1.to_async())
         .await
@@ -617,12 +640,15 @@ async fn test_calls_on_multiple_connections(
     assert!(incoming_call_b2.next().await.unwrap().is_some());
 
     // User A disconnects, causing both connections to stop ringing.
+    server.forbid_connections();
     server.disconnect_client(client_a.peer_id().unwrap());
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     assert!(incoming_call_b1.next().await.unwrap().is_none());
     assert!(incoming_call_b2.next().await.unwrap().is_none());
 
     // User A reconnects automatically, then calls user B again.
+    server.allow_connections();
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
     active_call_a
         .update(cx_a, |call, cx| {
             call.invite(client_b1.user_id().unwrap(), None, cx)
@@ -637,7 +663,7 @@ async fn test_calls_on_multiple_connections(
     server.forbid_connections();
     server.disconnect_client(client_b1.peer_id().unwrap());
     server.disconnect_client(client_b2.peer_id().unwrap());
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     active_call_a.read_with(cx_a, |call, _| assert!(call.room().is_none()));
 }
 
@@ -927,8 +953,9 @@ async fn test_host_disconnect(
     assert!(cx_b.is_window_edited(workspace_b.window_id()));
 
     // Drop client A's connection. Collaborators should disappear and the project should not be shown as shared.
+    server.forbid_connections();
     server.disconnect_client(client_a.peer_id().unwrap());
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     project_a
         .condition(cx_a, |project, _| project.collaborators().is_empty())
         .await;
@@ -951,6 +978,11 @@ async fn test_host_disconnect(
         .unwrap();
     assert!(can_close);
 
+    // Allow client A to reconnect to the server.
+    server.allow_connections();
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
+
+    // Client B calls client A again after they reconnected.
     let active_call_b = cx_b.read(ActiveCall::global);
     active_call_b
         .update(cx_b, |call, cx| {
@@ -971,7 +1003,7 @@ async fn test_host_disconnect(
 
     // Drop client A's connection again. We should still unshare it successfully.
     server.disconnect_client(client_a.peer_id().unwrap());
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
     project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
 }
 
@@ -2297,7 +2329,7 @@ async fn test_leaving_project(
     // Simulate connection loss for client C and ensure client A observes client C leaving the project.
     client_c.wait_for_current_user(cx_c).await;
     server.disconnect_client(client_c.peer_id().unwrap());
-    cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
+    cx_a.foreground().advance_clock(RECEIVE_TIMEOUT);
     deterministic.run_until_parked();
     project_a.read_with(cx_a, |project, _| {
         assert_eq!(project.collaborators().len(), 0);
@@ -4230,7 +4262,7 @@ async fn test_contacts(
 
     server.disconnect_client(client_c.peer_id().unwrap());
     server.forbid_connections();
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     assert_eq!(
         contacts(&client_a, cx_a),
         [
@@ -4534,7 +4566,7 @@ async fn test_contacts(
 
     server.forbid_connections();
     server.disconnect_client(client_a.peer_id().unwrap());
-    deterministic.advance_clock(rpc::RECEIVE_TIMEOUT);
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
     assert_eq!(contacts(&client_a, cx_a), []);
     assert_eq!(
         contacts(&client_b, cx_b),
@@ -5630,7 +5662,6 @@ async fn test_random_collaboration(
 
     let mut clients = Vec::new();
     let mut user_ids = Vec::new();
-    let mut peer_ids = Vec::new();
     let mut op_start_signals = Vec::new();
     let mut next_entity_id = 100000;
 
@@ -5657,7 +5688,6 @@ async fn test_random_collaboration(
                 let op_start_signal = futures::channel::mpsc::unbounded();
                 let guest = server.create_client(&mut guest_cx, &guest_username).await;
                 user_ids.push(guest.current_user_id(&guest_cx));
-                peer_ids.push(guest.peer_id().unwrap());
                 op_start_signals.push(op_start_signal.0);
                 clients.push(guest_cx.foreground().spawn(guest.simulate(
                     guest_username.clone(),
@@ -5669,16 +5699,26 @@ async fn test_random_collaboration(
                 log::info!("Added connection for {}", guest_username);
                 operations += 1;
             }
-            20..=29 if clients.len() > 1 => {
+            20..=24 if clients.len() > 1 => {
                 let guest_ix = rng.lock().gen_range(1..clients.len());
-                log::info!("Removing guest {}", user_ids[guest_ix]);
+                log::info!(
+                    "Simulating full disconnection of guest {}",
+                    user_ids[guest_ix]
+                );
                 let removed_guest_id = user_ids.remove(guest_ix);
-                let removed_peer_id = peer_ids.remove(guest_ix);
+                let user_connection_ids = server
+                    .connection_pool
+                    .lock()
+                    .await
+                    .user_connection_ids(removed_guest_id)
+                    .collect::<Vec<_>>();
+                assert_eq!(user_connection_ids.len(), 1);
+                let removed_peer_id = PeerId(user_connection_ids[0].0);
                 let guest = clients.remove(guest_ix);
                 op_start_signals.remove(guest_ix);
                 server.forbid_connections();
                 server.disconnect_client(removed_peer_id);
-                deterministic.advance_clock(RECEIVE_TIMEOUT);
+                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
                 deterministic.start_waiting();
                 log::info!("Waiting for guest {} to exit...", removed_guest_id);
                 let (guest, mut guest_cx) = guest.await;
@@ -5712,6 +5752,22 @@ async fn test_random_collaboration(
 
                 operations += 1;
             }
+            25..=29 if clients.len() > 1 => {
+                let guest_ix = rng.lock().gen_range(1..clients.len());
+                let user_id = user_ids[guest_ix];
+                log::info!("Simulating temporary disconnection of guest {}", user_id);
+                let user_connection_ids = server
+                    .connection_pool
+                    .lock()
+                    .await
+                    .user_connection_ids(user_id)
+                    .collect::<Vec<_>>();
+                assert_eq!(user_connection_ids.len(), 1);
+                let peer_id = PeerId(user_connection_ids[0].0);
+                server.disconnect_client(peer_id);
+                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+                operations += 1;
+            }
             _ if !op_start_signals.is_empty() => {
                 while operations < max_operations && rng.lock().gen_bool(0.7) {
                     op_start_signals
@@ -6000,7 +6056,7 @@ impl TestServer {
                                 client_name,
                                 user,
                                 Some(connection_id_tx),
-                                cx.background(),
+                                Executor::Deterministic(cx.background()),
                             ))
                             .detach();
                         let connection_id = connection_id_rx.await.unwrap();
@@ -6137,6 +6193,7 @@ impl Deref for TestServer {
 impl Drop for TestServer {
     fn drop(&mut self) {
         self.peer.reset();
+        self.server.teardown();
         self.test_live_kit_server.teardown().unwrap();
     }
 }
@@ -6397,11 +6454,14 @@ impl TestClient {
                         .clone()
                 }
             };
-            if let Err(error) = active_call
-                .update(cx, |call, cx| call.share_project(project.clone(), cx))
-                .await
-            {
-                log::error!("{}: error sharing project, {:?}", username, error);
+
+            if active_call.read_with(cx, |call, _| call.room().is_some()) {
+                if let Err(error) = active_call
+                    .update(cx, |call, cx| call.share_project(project.clone(), cx))
+                    .await
+                {
+                    log::error!("{}: error sharing project, {:?}", username, error);
+                }
             }
 
             let buffers = client.buffers.entry(project.clone()).or_default();
@@ -6829,18 +6889,6 @@ impl Drop for TestClient {
     }
 }
 
-impl Executor for Arc<gpui::executor::Background> {
-    type Sleep = gpui::executor::Timer;
-
-    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
-        self.spawn(future).detach();
-    }
-
-    fn sleep(&self, duration: Duration) -> Self::Sleep {
-        self.as_ref().timer(duration)
-    }
-}
-
 #[derive(Debug, Eq, PartialEq)]
 struct RoomParticipants {
     remote: Vec<String>,

crates/collab/src/lib.rs 🔗

@@ -2,6 +2,7 @@ pub mod api;
 pub mod auth;
 pub mod db;
 pub mod env;
+mod executor;
 #[cfg(test)]
 mod integration_tests;
 pub mod rpc;

crates/collab/src/rpc.rs 🔗

@@ -3,6 +3,7 @@ mod connection_pool;
 use crate::{
     auth,
     db::{self, Database, ProjectId, RoomId, User, UserId},
+    executor::Executor,
     AppState, Result,
 };
 use anyhow::anyhow;
@@ -52,13 +53,12 @@ use std::{
     },
     time::Duration,
 };
-use tokio::{
-    sync::{Mutex, MutexGuard},
-    time::Sleep,
-};
+use tokio::sync::{watch, Mutex, MutexGuard};
 use tower::ServiceBuilder;
 use tracing::{info_span, instrument, Instrument};
 
+pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
+
 lazy_static! {
     static ref METRIC_CONNECTIONS: IntGauge =
         register_int_gauge!("connections", "number of connections").unwrap();
@@ -143,17 +143,9 @@ pub struct Server {
     pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
     app_state: Arc<AppState>,
     handlers: HashMap<TypeId, MessageHandler>,
+    teardown: watch::Sender<()>,
 }
 
-pub trait Executor: Send + Clone {
-    type Sleep: Send + Future;
-    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
-    fn sleep(&self, duration: Duration) -> Self::Sleep;
-}
-
-#[derive(Clone)]
-pub struct RealExecutor;
-
 pub(crate) struct ConnectionPoolGuard<'a> {
     guard: MutexGuard<'a, ConnectionPool>,
     _not_send: PhantomData<Rc<()>>,
@@ -182,6 +174,7 @@ impl Server {
             app_state,
             connection_pool: Default::default(),
             handlers: Default::default(),
+            teardown: watch::channel(()).0,
         };
 
         server
@@ -244,6 +237,10 @@ impl Server {
         Arc::new(server)
     }
 
+    pub fn teardown(&self) {
+        let _ = self.teardown.send(());
+    }
+
     fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
         F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
@@ -330,29 +327,25 @@ impl Server {
         })
     }
 
-    pub fn handle_connection<E: Executor>(
+    pub fn handle_connection(
         self: &Arc<Self>,
         connection: Connection,
         address: String,
         user: User,
         mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
-        executor: E,
+        executor: Executor,
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         let user_id = user.id;
         let login = user.github_login;
         let span = info_span!("handle connection", %user_id, %login, %address);
+        let teardown = self.teardown.subscribe();
         async move {
             let (connection_id, handle_io, mut incoming_rx) = this
                 .peer
                 .add_connection(connection, {
                     let executor = executor.clone();
-                    move |duration| {
-                        let timer = executor.sleep(duration);
-                        async move {
-                            timer.await;
-                        }
-                    }
+                    move |duration| executor.sleep(duration)
                 });
 
             tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
@@ -452,7 +445,7 @@ impl Server {
 
             drop(foreground_message_handlers);
             tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
-            if let Err(error) = sign_out(session).await {
+            if let Err(error) = sign_out(session, teardown, executor).await {
                 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
             }
 
@@ -543,18 +536,6 @@ impl<'a> Drop for ConnectionPoolGuard<'a> {
     }
 }
 
-impl Executor for RealExecutor {
-    type Sleep = Sleep;
-
-    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
-        tokio::task::spawn(future);
-    }
-
-    fn sleep(&self, duration: Duration) -> Self::Sleep {
-        tokio::time::sleep(duration)
-    }
-}
-
 fn broadcast<F>(
     sender_id: ConnectionId,
     receiver_ids: impl IntoIterator<Item = ConnectionId>,
@@ -636,7 +617,7 @@ pub async fn handle_websocket_request(
         let connection = Connection::new(Box::pin(socket));
         async move {
             server
-                .handle_connection(connection, socket_address, user, None, RealExecutor)
+                .handle_connection(connection, socket_address, user, None, Executor::Production)
                 .await
                 .log_err();
         }
@@ -665,30 +646,48 @@ pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result
     Ok(encoded_metrics)
 }
 
-#[instrument(err)]
-async fn sign_out(session: Session) -> Result<()> {
+#[instrument(err, skip(executor))]
+async fn sign_out(
+    session: Session,
+    mut teardown: watch::Receiver<()>,
+    executor: Executor,
+) -> Result<()> {
     session.peer.disconnect(session.connection_id);
-    let decline_calls = {
-        let mut pool = session.connection_pool().await;
-        pool.remove_connection(session.connection_id)?;
-        let mut connections = pool.user_connection_ids(session.user_id);
-        connections.next().is_none()
-    };
+    session
+        .connection_pool()
+        .await
+        .remove_connection(session.connection_id)?;
 
-    leave_room_for_session(&session).await.trace_err();
-    if decline_calls {
-        if let Some(room) = session
-            .db()
-            .await
-            .decline_call(None, session.user_id)
-            .await
-            .trace_err()
-        {
-            room_updated(&room, &session);
+    if let Some(mut left_projects) = session
+        .db()
+        .await
+        .connection_lost(session.connection_id)
+        .await
+        .trace_err()
+    {
+        for left_project in mem::take(&mut *left_projects) {
+            project_left(&left_project, &session);
         }
     }
 
-    update_user_contacts(session.user_id, &session).await?;
+    futures::select_biased! {
+        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
+            leave_room_for_session(&session).await.trace_err();
+
+            if !session
+                .connection_pool()
+                .await
+                .is_user_online(session.user_id)
+            {
+                let db = session.db().await;
+                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
+                    room_updated(&room, &session);
+                }
+            }
+            update_user_contacts(session.user_id, &session).await?;
+        }
+        _ = teardown.changed().fuse() => {}
+    }
 
     Ok(())
 }
@@ -1118,20 +1117,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
         host_connection_id = %project.host_connection_id,
         "leave project"
     );
-
-    broadcast(
-        sender_id,
-        project.connection_ids.iter().copied(),
-        |conn_id| {
-            session.peer.send(
-                conn_id,
-                proto::RemoveProjectCollaborator {
-                    project_id: project_id.to_proto(),
-                    peer_id: sender_id.0,
-                },
-            )
-        },
-    );
+    project_left(&project, &session);
 
     Ok(())
 }
@@ -1862,40 +1848,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
         contacts_to_update.insert(session.user_id);
 
         for project in left_room.left_projects.values() {
-            for connection_id in &project.connection_ids {
-                if project.host_user_id == session.user_id {
-                    session
-                        .peer
-                        .send(
-                            *connection_id,
-                            proto::UnshareProject {
-                                project_id: project.id.to_proto(),
-                            },
-                        )
-                        .trace_err();
-                } else {
-                    session
-                        .peer
-                        .send(
-                            *connection_id,
-                            proto::RemoveProjectCollaborator {
-                                project_id: project.id.to_proto(),
-                                peer_id: session.connection_id.0,
-                            },
-                        )
-                        .trace_err();
-                }
-            }
-
-            session
-                .peer
-                .send(
-                    session.connection_id,
-                    proto::UnshareProject {
-                        project_id: project.id.to_proto(),
-                    },
-                )
-                .trace_err();
+            project_left(project, session);
         }
 
         room_updated(&left_room.room, &session);
@@ -1935,6 +1888,43 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
     Ok(())
 }
 
+fn project_left(project: &db::LeftProject, session: &Session) {
+    for connection_id in &project.connection_ids {
+        if project.host_user_id == session.user_id {
+            session
+                .peer
+                .send(
+                    *connection_id,
+                    proto::UnshareProject {
+                        project_id: project.id.to_proto(),
+                    },
+                )
+                .trace_err();
+        } else {
+            session
+                .peer
+                .send(
+                    *connection_id,
+                    proto::RemoveProjectCollaborator {
+                        project_id: project.id.to_proto(),
+                        peer_id: session.connection_id.0,
+                    },
+                )
+                .trace_err();
+        }
+    }
+
+    session
+        .peer
+        .send(
+            session.connection_id,
+            proto::UnshareProject {
+                project_id: project.id.to_proto(),
+            },
+        )
+        .trace_err();
+}
+
 pub trait ResultExt {
     type Ok;