WIP

Antonio Scandurra created this commit

Change summary

crates/client/src/client.rs                                                   |   4 
crates/collab/k8s/manifest.template.yml                                       |   2 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql                |   5 
crates/collab/migrations/20221214144346_change_epoch_from_uuid_to_integer.sql |  13 
crates/collab/src/db.rs                                                       | 166 
crates/collab/src/db/server.rs                                                |  15 
crates/collab/src/integration_tests.rs                                        |  24 
crates/collab/src/lib.rs                                                      |   1 
crates/collab/src/main.rs                                                     |   6 
crates/collab/src/rpc.rs                                                      |  41 
crates/rpc/src/peer.rs                                                        |  19 
11 files changed, 200 insertions(+), 96 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -333,14 +333,14 @@ impl Client {
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn tear_down(&self) {
+    pub fn teardown(&self) {
         let mut state = self.state.write();
         state._reconnect_task.take();
         state.message_handlers.clear();
         state.models_by_message_type.clear();
         state.entities_by_type_and_remote_id.clear();
         state.entity_id_extractors.clear();
-        self.peer.reset();
+        self.peer.teardown();
     }
 
     #[cfg(any(test, feature = "test-support"))]
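
The rename is more than cosmetic: Peer::reset now re-seeds the epoch and rewinds the connection-id counter (see crates/rpc/src/peer.rs below), so this client-side test helper switches to the new Peer::teardown, which only drops live connections.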

crates/collab/k8s/manifest.template.yml 🔗

@@ -99,6 +99,8 @@ spec:
               value: ${RUST_LOG}
             - name: LOG_JSON
               value: "true"
+            - name: ZED_ENVIRONMENT
+              value: ${ZED_ENVIRONMENT}
           securityContext:
             capabilities:
               # FIXME - Switch to the more restrictive `PERFMON` capability.
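
The deployment now forwards ZED_ENVIRONMENT into the collab pod so each server can register itself under the right environment. A minimal sketch of how the variable could reach the Config struct extended in crates/collab/src/lib.rs below, assuming an envy-style environment deserializer (the loader itself is not part of this diff):

    // Hypothetical loader; the zed_environment field matches the
    // Config change in crates/collab/src/lib.rs below.
    #[derive(serde::Deserialize)]
    struct Config {
        zed_environment: String, // e.g. "production" or "staging" (assumed values)
        // ...remaining fields elided
    }

    fn load_config() -> Config {
        envy::from_env::<Config>().expect("failed to load config from environment")
    }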

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -134,3 +134,8 @@ CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_pa
 CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch");
 CREATE INDEX "index_room_participants_on_answering_connection_id" ON "room_participants" ("answering_connection_id");
 CREATE UNIQUE INDEX "index_room_participants_on_answering_connection_id_and_answering_connection_epoch" ON "room_participants" ("answering_connection_id", "answering_connection_epoch");
+
+CREATE TABLE "servers" (
+    "epoch" INTEGER PRIMARY KEY AUTOINCREMENT,
+    "environment" VARCHAR NOT NULL
+);
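
This mirrors the Postgres migration below: INTEGER PRIMARY KEY AUTOINCREMENT is SQLite's stand-in for SERIAL, so both backends hand out fresh, monotonically increasing epochs starting at 1.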

crates/collab/migrations/20221214144346_change_epoch_from_uuid_to_integer.sql 🔗

@@ -1,9 +1,14 @@
 ALTER TABLE "projects"
-    ALTER COLUMN "host_connection_epoch" TYPE INTEGER USING 0;
+    ALTER COLUMN "host_connection_epoch" TYPE INTEGER USING -1;
 
 ALTER TABLE "project_collaborators"
-    ALTER COLUMN "connection_epoch" TYPE INTEGER USING 0;
+    ALTER COLUMN "connection_epoch" TYPE INTEGER USING -1;
 
 ALTER TABLE "room_participants"
-    ALTER COLUMN "answering_connection_epoch" TYPE INTEGER USING 0,
-    ALTER COLUMN "calling_connection_epoch" TYPE INTEGER USING 0;
+    ALTER COLUMN "answering_connection_epoch" TYPE INTEGER USING -1,
+    ALTER COLUMN "calling_connection_epoch" TYPE INTEGER USING -1;
+
+CREATE TABLE "servers" (
+    "epoch" SERIAL PRIMARY KEY,
+    "environment" VARCHAR NOT NULL
+);
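
Re-casting the old UUID columns with USING -1 rather than 0 looks deliberate: the previous scheme hard-coded Peer::new(0) (see crates/collab/src/rpc.rs below), so epoch 0 may already appear in persisted connection ids, while SERIAL values start at 1. An epoch of -1 can match neither, guaranteeing every pre-migration row is treated as stale and swept.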

crates/collab/src/db.rs 🔗

@@ -5,6 +5,7 @@ mod project;
 mod project_collaborator;
 mod room;
 mod room_participant;
+mod server;
 mod signup;
 #[cfg(test)]
 mod tests;
@@ -48,7 +49,6 @@ pub struct Database {
     background: Option<std::sync::Arc<gpui::executor::Background>>,
     #[cfg(test)]
     runtime: Option<tokio::runtime::Runtime>,
-    epoch: parking_lot::RwLock<Uuid>,
 }
 
 impl Database {
@@ -61,18 +61,12 @@ impl Database {
             background: None,
             #[cfg(test)]
             runtime: None,
-            epoch: parking_lot::RwLock::new(Uuid::new_v4()),
         })
     }
 
     #[cfg(test)]
     pub fn reset(&self) {
         self.rooms.clear();
-        *self.epoch.write() = Uuid::new_v4();
-    }
-
-    fn epoch(&self) -> Uuid {
-        *self.epoch.read()
     }
 
     pub async fn migrate(
@@ -116,14 +110,39 @@ impl Database {
         Ok(new_migrations)
     }
 
-    pub async fn delete_stale_projects(&self) -> Result<()> {
+    pub async fn create_server(&self, environment: &str) -> Result<ServerEpoch> {
         self.transaction(|tx| async move {
+            let server = server::ActiveModel {
+                environment: ActiveValue::set(environment.into()),
+                ..Default::default()
+            }
+            .insert(&*tx)
+            .await?;
+            Ok(server.epoch)
+        })
+        .await
+    }
+
+    pub async fn delete_stale_projects(
+        &self,
+        new_epoch: ServerEpoch,
+        environment: &str,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            let stale_server_epochs = self
+                .stale_server_epochs(environment, new_epoch, &tx)
+                .await?;
             project_collaborator::Entity::delete_many()
-                .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch()))
+                .filter(
+                    project_collaborator::Column::ConnectionEpoch
+                        .is_in(stale_server_epochs.iter().copied()),
+                )
                 .exec(&*tx)
                 .await?;
             project::Entity::delete_many()
-                .filter(project::Column::HostConnectionEpoch.ne(self.epoch()))
+                .filter(
+                    project::Column::HostConnectionEpoch.is_in(stale_server_epochs.iter().copied()),
+                )
                 .exec(&*tx)
                 .await?;
             Ok(())
@@ -131,18 +150,27 @@ impl Database {
         .await
     }
 
-    pub async fn stale_room_ids(&self) -> Result<Vec<RoomId>> {
+    pub async fn stale_room_ids(
+        &self,
+        new_epoch: ServerEpoch,
+        environment: &str,
+    ) -> Result<Vec<RoomId>> {
         self.transaction(|tx| async move {
             #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
             enum QueryAs {
                 RoomId,
             }
 
+            let stale_server_epochs = self
+                .stale_server_epochs(environment, new_epoch, &tx)
+                .await?;
             Ok(room_participant::Entity::find()
                 .select_only()
                 .column(room_participant::Column::RoomId)
                 .distinct()
-                .filter(room_participant::Column::AnsweringConnectionEpoch.ne(self.epoch()))
+                .filter(
+                    room_participant::Column::AnsweringConnectionEpoch.is_in(stale_server_epochs),
+                )
                 .into_values::<_, QueryAs>()
                 .all(&*tx)
                 .await?)
@@ -150,12 +178,16 @@ impl Database {
         .await
     }
 
-    pub async fn refresh_room(&self, room_id: RoomId) -> Result<RoomGuard<RefreshedRoom>> {
+    pub async fn refresh_room(
+        &self,
+        room_id: RoomId,
+        new_epoch: ServerEpoch,
+    ) -> Result<RoomGuard<RefreshedRoom>> {
         self.room_transaction(|tx| async move {
             let stale_participant_filter = Condition::all()
                 .add(room_participant::Column::RoomId.eq(room_id))
                 .add(room_participant::Column::AnsweringConnectionId.is_not_null())
-                .add(room_participant::Column::AnsweringConnectionEpoch.ne(self.epoch()));
+                .add(room_participant::Column::AnsweringConnectionEpoch.ne(new_epoch));
 
             let stale_participant_user_ids = room_participant::Entity::find()
                 .filter(stale_participant_filter.clone())
@@ -199,6 +231,35 @@ impl Database {
         .await
     }
 
+    async fn delete_stale_servers(&self, environment: &str, new_epoch: ServerEpoch) -> Result<()> {
+        self.transaction(|tx| async move {
+            server::Entity::delete_many()
+                .filter(server::Column::Environment.eq(environment))
+                .filter(server::Column::Epoch.ne(new_epoch))
+                .exec(&*tx)
+                .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    async fn stale_server_epochs(
+        &self,
+        environment: &str,
+        new_epoch: ServerEpoch,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<ServerEpoch>> {
+        let stale_servers = server::Entity::find()
+            .filter(
+                Condition::all()
+                    .add(server::Column::Environment.eq(environment))
+                    .add(server::Column::Epoch.ne(new_epoch)),
+            )
+            .all(&*tx)
+            .await?;
+        Ok(stale_servers
+            .into_iter()
+            .map(|server| server.epoch)
+            .collect())
+    }
+
     // users
 
     pub async fn create_user(
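
Staleness here is defined as: same environment, different epoch. Because every sweep above filters with is_in(stale_server_epochs), data written by servers in other environments sharing the database is left untouched; only earlier generations of the current environment are deleted. Note that delete_stale_servers itself is not wired up to anything yet, which fits the WIP title.
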
@@ -1643,14 +1704,16 @@ impl Database {
                         Default::default(),
                     )),
                 };
+
+                let answering_connection = ConnectionId {
+                    epoch: answering_connection_epoch as u32,
+                    id: answering_connection_id as u32,
+                };
                 participants.insert(
-                    answering_connection_id,
+                    answering_connection,
                     proto::Participant {
                         user_id: db_participant.user_id.to_proto(),
-                        peer_id: Some(proto::PeerId {
-                            epoch: answering_connection_epoch as u32,
-                            id: answering_connection_id as u32,
-                        }),
+                        peer_id: Some(answering_connection.into()),
                         projects: Default::default(),
                         location: Some(proto::ParticipantLocation { variant: location }),
                     },
@@ -1673,7 +1736,11 @@ impl Database {
 
         while let Some(row) = db_projects.next().await {
             let (db_project, db_worktree) = row?;
-            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
+            let host_connection = ConnectionId {
+                epoch: db_project.host_connection_epoch as u32,
+                id: db_project.host_connection_id as u32,
+            };
+            if let Some(participant) = participants.get_mut(&host_connection) {
                 let project = if let Some(project) = participant
                     .projects
                     .iter_mut()
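
Both hunks above depend on the participants map now being keyed by the full ConnectionId (epoch plus per-epoch id) rather than a bare integer, so connections from different server generations cannot collide. A sketch of the key type this assumes, plus the conversion implied by answering_connection.into() (the real definitions live in the rpc crate and are not part of this diff):

    // Assumed shape of rpc::ConnectionId; the field names follow the diff.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
    pub struct ConnectionId {
        pub epoch: u32,
        pub id: u32,
    }

    // `answering_connection.into()` implies a conversion along these lines.
    impl From<ConnectionId> for proto::PeerId {
        fn from(id: ConnectionId) -> Self {
            proto::PeerId { epoch: id.epoch, id: id.id }
        }
    }
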
@@ -2309,41 +2376,22 @@ impl Database {
         project_id: ProjectId,
         connection_id: ConnectionId,
     ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
-        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-        enum QueryAs {
-            Epoch,
-            Id,
-        }
-
-        #[derive(Debug, FromQueryResult)]
-        struct ProjectConnection {
-            epoch: i32,
-            id: i32,
-        }
-
         self.room_transaction(|tx| async move {
             let project = project::Entity::find_by_id(project_id)
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("no such project"))?;
-            let mut db_connections = project_collaborator::Entity::find()
-                .select_only()
-                .column_as(project_collaborator::Column::ConnectionId, QueryAs::Id)
-                .column_as(
-                    project_collaborator::Column::ConnectionEpoch,
-                    QueryAs::Epoch,
-                )
+            let mut participants = project_collaborator::Entity::find()
                 .filter(project_collaborator::Column::ProjectId.eq(project_id))
-                .into_model::<ProjectConnection>()
                 .stream(&*tx)
                 .await?;
 
             let mut connection_ids = HashSet::default();
-            while let Some(connection) = db_connections.next().await {
-                let connection = connection?;
+            while let Some(participant) = participants.next().await {
+                let participant = participant?;
                 connection_ids.insert(ConnectionId {
-                    epoch: connection.epoch as u32,
-                    id: connection.id as u32,
+                    epoch: participant.connection_epoch as u32,
+                    id: participant.connection_id as u32,
                 });
             }
 
@@ -2361,40 +2409,21 @@ impl Database {
         project_id: ProjectId,
         tx: &DatabaseTransaction,
     ) -> Result<Vec<ConnectionId>> {
-        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-        enum QueryAs {
-            Epoch,
-            Id,
-        }
-
-        #[derive(Debug, FromQueryResult)]
-        struct ProjectConnection {
-            epoch: i32,
-            id: i32,
-        }
-
-        let mut db_guest_connections = project_collaborator::Entity::find()
-            .select_only()
-            .column_as(project_collaborator::Column::ConnectionId, QueryAs::Id)
-            .column_as(
-                project_collaborator::Column::ConnectionEpoch,
-                QueryAs::Epoch,
-            )
+        let mut participants = project_collaborator::Entity::find()
             .filter(
                 project_collaborator::Column::ProjectId
                     .eq(project_id)
                     .and(project_collaborator::Column::IsHost.eq(false)),
             )
-            .into_model::<ProjectConnection>()
             .stream(tx)
             .await?;
 
         let mut guest_connection_ids = Vec::new();
-        while let Some(connection) = db_guest_connections.next().await {
-            let connection = connection?;
+        while let Some(participant) = participants.next().await {
+            let participant = participant?;
             guest_connection_ids.push(ConnectionId {
-                epoch: connection.epoch as u32,
-                id: connection.id as u32,
+                epoch: participant.connection_epoch as u32,
+                id: participant.connection_id as u32,
             });
         }
         Ok(guest_connection_ids)
@@ -2774,6 +2803,7 @@ id_type!(RoomParticipantId);
 id_type!(ProjectId);
 id_type!(ProjectCollaboratorId);
 id_type!(ReplicaId);
+id_type!(ServerEpoch);
 id_type!(SignupId);
 id_type!(UserId);
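
ServerEpoch joins the existing family of id newtypes. Judging by the epoch.0 as u32 casts in rpc.rs below, the macro wraps the i32 column type, roughly:

    // Approximate expansion of id_type!(ServerEpoch); the real macro
    // presumably also derives sea-orm and serde plumbing.
    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
    pub struct ServerEpoch(pub i32);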
 

crates/collab/src/db/server.rs 🔗

@@ -0,0 +1,15 @@
+use super::ServerEpoch;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "servers")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub epoch: ServerEpoch,
+    pub environment: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
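
DeriveEntityModel generates the Entity, Column, and ActiveModel companions for this Model. That is what lets create_server write server::ActiveModel { environment: ..., ..Default::default() }: leaving epoch unset keeps the primary key out of the INSERT, so the database's AUTOINCREMENT/SERIAL assigns it and insert returns the row with the new epoch filled in.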

crates/collab/src/integration_tests.rs 🔗

@@ -608,7 +608,7 @@ async fn test_server_restarts(
     );
 
     // The server is torn down.
-    server.teardown();
+    server.reset().await;
 
     // Users A and B reconnect to the call. User C has troubles reconnecting, so it leaves the room.
     client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending()));
@@ -778,7 +778,7 @@ async fn test_server_restarts(
     );
 
     // The server is torn down.
-    server.teardown();
+    server.reset().await;
 
     // Users A and B have troubles reconnecting, so they leave the room.
     client_a.override_establish_connection(|_, cx| cx.spawn(|_| future::pending()));
@@ -6122,7 +6122,7 @@ async fn test_random_collaboration(
             }
             30..=34 => {
                 log::info!("Simulating server restart");
-                server.teardown();
+                server.reset().await;
                 deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
                 server.start().await.unwrap();
                 deterministic.advance_clock(CLEANUP_TIMEOUT);
@@ -6320,7 +6320,13 @@ impl TestServer {
         )
         .unwrap();
         let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
+        let epoch = app_state
+            .db
+            .create_server(&app_state.config.zed_environment)
+            .await
+            .unwrap();
         let server = Server::new(
+            epoch,
             app_state.clone(),
             Executor::Deterministic(deterministic.build_background()),
         );
@@ -6337,9 +6343,15 @@ impl TestServer {
         }
     }
 
-    fn teardown(&self) {
-        self.server.teardown();
+    async fn reset(&self) {
         self.app_state.db.reset();
+        let epoch = self
+            .app_state
+            .db
+            .create_server(&self.app_state.config.zed_environment)
+            .await
+            .unwrap();
+        self.server.reset(epoch);
     }
 
     async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
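
The helper is renamed from teardown to reset because it now does more than tear connections down: db.reset() clears the in-memory rooms, and the fresh create_server call allocates a new epoch, so a simulated restart behaves like a real redeploy in which the restarted server treats its predecessor's rows as stale.
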
@@ -7251,7 +7263,7 @@ impl TestClient {
 
 impl Drop for TestClient {
     fn drop(&mut self) {
-        self.client.tear_down();
+        self.client.teardown();
     }
 }
 

crates/collab/src/lib.rs 🔗

@@ -97,6 +97,7 @@ pub struct Config {
     pub live_kit_secret: Option<String>,
     pub rust_log: Option<String>,
     pub log_json: Option<bool>,
+    pub zed_environment: String,
 }
 
 #[derive(Default, Deserialize)]

crates/collab/src/main.rs 🔗

@@ -57,7 +57,11 @@ async fn main() -> Result<()> {
             let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
                 .expect("failed to bind TCP listener");
 
-            let rpc_server = collab::rpc::Server::new(state.clone(), Executor::Production);
+            let epoch = state
+                .db
+                .create_server(&state.config.zed_environment)
+                .await?;
+            let rpc_server = collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
             rpc_server.start().await?;
 
             let app = collab::api::routes(rpc_server.clone(), state.clone())
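
The production startup order now matters: register a servers row first, bake the returned epoch into the Server (and its Peer), and only then let rpc_server.start() run the stale-data sweeps with that epoch as the live one.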

crates/collab/src/rpc.rs 🔗

@@ -2,7 +2,7 @@ mod connection_pool;
 
 use crate::{
     auth,
-    db::{self, Database, ProjectId, RoomId, User, UserId},
+    db::{self, Database, ProjectId, RoomId, ServerEpoch, User, UserId},
     executor::Executor,
     AppState, Result,
 };
@@ -138,6 +138,7 @@ impl Deref for DbHandle {
 }
 
 pub struct Server {
+    epoch: parking_lot::Mutex<ServerEpoch>,
     peer: Arc<Peer>,
     pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     app_state: Arc<AppState>,
@@ -168,9 +169,10 @@ where
 }
 
 impl Server {
-    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
+    pub fn new(epoch: ServerEpoch, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
         let mut server = Self {
-            peer: Peer::new(0),
+            epoch: parking_lot::Mutex::new(epoch),
+            peer: Peer::new(epoch.0 as u32),
             app_state,
             executor,
             connection_pool: Default::default(),
@@ -239,7 +241,8 @@ impl Server {
     }
 
     pub async fn start(&self) -> Result<()> {
-        let db = self.app_state.db.clone();
+        let epoch = *self.epoch.lock();
+        let app_state = self.app_state.clone();
         let peer = self.peer.clone();
         let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
         let pool = self.connection_pool.clone();
@@ -249,7 +252,10 @@ impl Server {
         let span_enter = span.enter();
 
         tracing::info!("begin deleting stale projects");
-        self.app_state.db.delete_stale_projects().await?;
+        app_state
+            .db
+            .delete_stale_projects(epoch, &app_state.config.zed_environment)
+            .await?;
         tracing::info!("finish deleting stale projects");
 
         drop(span_enter);
@@ -258,7 +264,12 @@ impl Server {
                 tracing::info!("waiting for cleanup timeout");
                 timeout.await;
                 tracing::info!("cleanup timeout expired, retrieving stale rooms");
-                if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
+                if let Some(room_ids) = app_state
+                    .db
+                    .stale_room_ids(epoch, &app_state.config.zed_environment)
+                    .await
+                    .trace_err()
+                {
                     tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                     for room_id in room_ids {
                         let mut contacts_to_update = HashSet::default();
@@ -266,7 +277,9 @@ impl Server {
                         let mut live_kit_room = String::new();
                         let mut delete_live_kit_room = false;
 
-                        if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
+                        if let Ok(mut refreshed_room) =
+                            app_state.db.refresh_room(room_id, epoch).await
+                        {
                             tracing::info!(
                                 room_id = room_id.0,
                                 new_participant_count = refreshed_room.room.participants.len(),
@@ -299,8 +312,8 @@ impl Server {
                         }
 
                         for user_id in contacts_to_update {
-                            let busy = db.is_user_busy(user_id).await.trace_err();
-                            let contacts = db.get_contacts(user_id).await.trace_err();
+                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
+                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                             if let Some((busy, contacts)) = busy.zip(contacts) {
                                 let pool = pool.lock();
                                 let updated_contact = contact_for_user(user_id, false, busy, &pool);
@@ -345,11 +358,18 @@ impl Server {
     }
 
     pub fn teardown(&self) {
-        self.peer.reset();
+        self.peer.teardown();
         self.connection_pool.lock().reset();
         let _ = self.teardown.send(());
     }
 
+    #[cfg(test)]
+    pub fn reset(&self, epoch: ServerEpoch) {
+        self.teardown();
+        *self.epoch.lock() = epoch;
+        self.peer.reset(epoch.0 as u32);
+    }
+
     fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
         F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
@@ -1474,6 +1494,7 @@ async fn update_buffer(
         .project_connection_ids(project_id, session.connection_id)
         .await?;
 
+    dbg!(session.connection_id, &*project_connection_ids);
     broadcast(
         session.connection_id,
         project_connection_ids.iter().copied(),

crates/rpc/src/peer.rs 🔗

@@ -97,7 +97,7 @@ impl<T: RequestMessage> TypedEnvelope<T> {
 }
 
 pub struct Peer {
-    epoch: u32,
+    epoch: AtomicU32,
     pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
     next_connection_id: AtomicU32,
 }
@@ -120,12 +120,16 @@ pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
 impl Peer {
     pub fn new(epoch: u32) -> Arc<Self> {
         Arc::new(Self {
-            epoch,
+            epoch: AtomicU32::new(epoch),
             connections: Default::default(),
             next_connection_id: Default::default(),
         })
     }
 
+    pub fn epoch(&self) -> u32 {
+        self.epoch.load(SeqCst)
+    }
+
     #[instrument(skip_all)]
     pub fn add_connection<F, Fut, Out>(
         self: &Arc<Self>,
@@ -153,7 +157,7 @@ impl Peer {
         let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
 
         let connection_id = ConnectionId {
-            epoch: self.epoch,
+            epoch: self.epoch.load(SeqCst),
             id: self.next_connection_id.fetch_add(1, SeqCst),
         };
         let connection_state = ConnectionState {
@@ -352,9 +356,14 @@ impl Peer {
         self.connections.write().remove(&connection_id);
     }
 
-    pub fn reset(&self) {
-        self.connections.write().clear();
+    pub fn reset(&self, epoch: u32) {
+        self.teardown();
         self.next_connection_id.store(0, SeqCst);
+        self.epoch.store(epoch, SeqCst);
+    }
+
+    pub fn teardown(&self) {
+        self.connections.write().clear();
     }
 
     pub fn request<T: RequestMessage>(
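
Splitting reset (rewind the connection counter and install a new epoch) from teardown (just drop live connections) is what makes the epoch an AtomicU32: tests swap epochs on a live Peer while connection setup loads it concurrently. The payoff, as a hypothetical illustration:

    // Connections created across a simulated restart remain distinguishable
    // even though next_connection_id rewinds to 0: the epoch differs.
    let before = ConnectionId { epoch: 1, id: 0 };
    let after = ConnectionId { epoch: 2, id: 0 };
    assert_ne!(before, after);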