Remove obsolete code from `Store`

Antonio Scandurra created

Change summary

crates/collab/src/db.rs        |  15 ++
crates/collab/src/main.rs      |  53 ---------
crates/collab/src/rpc.rs       |  60 +++++----
crates/collab/src/rpc/store.rs | 205 +----------------------------------
4 files changed, 58 insertions(+), 275 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -1464,6 +1464,21 @@ where
 
     // projects
 
+    pub async fn project_count_excluding_admins(&self) -> Result<usize> {
+        self.transact(|mut tx| async move {
+            Ok(sqlx::query_scalar::<_, i32>(
+                "
+                SELECT COUNT(*)
+                FROM projects, users
+                WHERE projects.host_user_id = users.id AND users.admin IS FALSE
+                ",
+            )
+            .fetch_one(&mut tx)
+            .await? as usize)
+        })
+        .await
+    }
+
     pub async fn share_project(
         &self,
         expected_room_id: RoomId,

crates/collab/src/main.rs 🔗

@@ -9,7 +9,6 @@ mod db_tests;
 #[cfg(test)]
 mod integration_tests;
 
-use crate::rpc::ResultExt as _;
 use anyhow::anyhow;
 use axum::{routing::get, Router};
 use collab::{Error, Result};
@@ -20,9 +19,7 @@ use std::{
     net::{SocketAddr, TcpListener},
     path::{Path, PathBuf},
     sync::Arc,
-    time::Duration,
 };
-use tokio::signal;
 use tracing_log::LogTracer;
 use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 use util::ResultExt;
@@ -129,7 +126,6 @@ async fn main() -> Result<()> {
 
             axum::Server::from_tcp(listener)?
                 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
-                .with_graceful_shutdown(graceful_shutdown(rpc_server, state))
                 .await?;
         }
         _ => {
@@ -174,52 +170,3 @@ pub fn init_tracing(config: &Config) -> Option<()> {
 
     None
 }
-
-async fn graceful_shutdown(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) {
-    let ctrl_c = async {
-        signal::ctrl_c()
-            .await
-            .expect("failed to install Ctrl+C handler");
-    };
-
-    #[cfg(unix)]
-    let terminate = async {
-        signal::unix::signal(signal::unix::SignalKind::terminate())
-            .expect("failed to install signal handler")
-            .recv()
-            .await;
-    };
-
-    #[cfg(not(unix))]
-    let terminate = std::future::pending::<()>();
-
-    tokio::select! {
-        _ = ctrl_c => {},
-        _ = terminate => {},
-    }
-
-    if let Some(live_kit) = state.live_kit_client.as_ref() {
-        let deletions = rpc_server
-            .store()
-            .await
-            .rooms()
-            .values()
-            .map(|room| {
-                let name = room.live_kit_room.clone();
-                async {
-                    live_kit.delete_room(name).await.trace_err();
-                }
-            })
-            .collect::<Vec<_>>();
-
-        tracing::info!("deleting all live-kit rooms");
-        if let Err(_) = tokio::time::timeout(
-            Duration::from_secs(10),
-            futures::future::join_all(deletions),
-        )
-        .await
-        {
-            tracing::error!("timed out waiting for live-kit room deletion");
-        }
-    }
-}

crates/collab/src/rpc.rs 🔗

@@ -49,7 +49,7 @@ use std::{
     },
     time::Duration,
 };
-pub use store::{Store, Worktree};
+pub use store::Store;
 use tokio::{
     sync::{Mutex, MutexGuard},
     time::Sleep,
@@ -437,7 +437,7 @@ impl Server {
         let decline_calls = {
             let mut store = self.store().await;
             store.remove_connection(connection_id)?;
-            let mut connections = store.connection_ids_for_user(user_id);
+            let mut connections = store.user_connection_ids(user_id);
             connections.next().is_none()
         };
 
@@ -470,7 +470,7 @@ impl Server {
             if let Some(code) = &user.invite_code {
                 let store = self.store().await;
                 let invitee_contact = store.contact_for_user(invitee_id, true, false);
-                for connection_id in store.connection_ids_for_user(inviter_id) {
+                for connection_id in store.user_connection_ids(inviter_id) {
                     self.peer.send(
                         connection_id,
                         proto::UpdateContacts {
@@ -495,7 +495,7 @@ impl Server {
         if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
             if let Some(invite_code) = &user.invite_code {
                 let store = self.store().await;
-                for connection_id in store.connection_ids_for_user(user_id) {
+                for connection_id in store.user_connection_ids(user_id) {
                     self.peer.send(
                         connection_id,
                         proto::UpdateInviteInfo {
@@ -582,7 +582,7 @@ impl Server {
                 session.connection_id,
             )
             .await?;
-        for connection_id in self.store().await.connection_ids_for_user(session.user_id) {
+        for connection_id in self.store().await.user_connection_ids(session.user_id) {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -674,7 +674,7 @@ impl Server {
         {
             let store = self.store().await;
             for canceled_user_id in left_room.canceled_calls_to_user_ids {
-                for connection_id in store.connection_ids_for_user(canceled_user_id) {
+                for connection_id in store.user_connection_ids(canceled_user_id) {
                     self.peer
                         .send(connection_id, proto::CallCanceled {})
                         .trace_err();
@@ -744,7 +744,7 @@ impl Server {
         let mut calls = self
             .store()
             .await
-            .connection_ids_for_user(called_user_id)
+            .user_connection_ids(called_user_id)
             .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
             .collect::<FuturesUnordered<_>>();
 
@@ -784,7 +784,7 @@ impl Server {
             .db
             .cancel_call(Some(room_id), session.connection_id, called_user_id)
             .await?;
-        for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
+        for connection_id in self.store().await.user_connection_ids(called_user_id) {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -807,7 +807,7 @@ impl Server {
             .db
             .decline_call(Some(room_id), session.user_id)
             .await?;
-        for connection_id in self.store().await.connection_ids_for_user(session.user_id) {
+        for connection_id in self.store().await.user_connection_ids(session.user_id) {
             self.peer
                 .send(connection_id, proto::CallCanceled {})
                 .trace_err();
@@ -905,7 +905,7 @@ impl Server {
                 ..
             } = contact
             {
-                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
+                for contact_conn_id in store.user_connection_ids(contact_user_id) {
                     self.peer
                         .send(
                             contact_conn_id,
@@ -1522,7 +1522,7 @@ impl Server {
         // Update outgoing contact requests of requester
         let mut update = proto::UpdateContacts::default();
         update.outgoing_requests.push(responder_id.to_proto());
-        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
+        for connection_id in self.store().await.user_connection_ids(requester_id) {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1534,7 +1534,7 @@ impl Server {
                 requester_id: requester_id.to_proto(),
                 should_notify: true,
             });
-        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
+        for connection_id in self.store().await.user_connection_ids(responder_id) {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1574,7 +1574,7 @@ impl Server {
             update
                 .remove_incoming_requests
                 .push(requester_id.to_proto());
-            for connection_id in store.connection_ids_for_user(responder_id) {
+            for connection_id in store.user_connection_ids(responder_id) {
                 self.peer.send(connection_id, update.clone())?;
             }
 
@@ -1588,7 +1588,7 @@ impl Server {
             update
                 .remove_outgoing_requests
                 .push(responder_id.to_proto());
-            for connection_id in store.connection_ids_for_user(requester_id) {
+            for connection_id in store.user_connection_ids(requester_id) {
                 self.peer.send(connection_id, update.clone())?;
             }
         }
@@ -1615,7 +1615,7 @@ impl Server {
         update
             .remove_outgoing_requests
             .push(responder_id.to_proto());
-        for connection_id in self.store().await.connection_ids_for_user(requester_id) {
+        for connection_id in self.store().await.user_connection_ids(requester_id) {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1624,7 +1624,7 @@ impl Server {
         update
             .remove_incoming_requests
             .push(requester_id.to_proto());
-        for connection_id in self.store().await.connection_ids_for_user(responder_id) {
+        for connection_id in self.store().await.user_connection_ids(responder_id) {
             self.peer.send(connection_id, update.clone())?;
         }
 
@@ -1819,21 +1819,25 @@ pub async fn handle_websocket_request(
     })
 }
 
-pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
-    let metrics = server.store().await.metrics();
-    METRIC_CONNECTIONS.set(metrics.connections as _);
-    METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
+pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
+    let connections = server
+        .store()
+        .await
+        .connections()
+        .filter(|connection| !connection.admin)
+        .count();
+
+    METRIC_CONNECTIONS.set(connections as _);
+
+    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
+    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 
     let encoder = prometheus::TextEncoder::new();
     let metric_families = prometheus::gather();
-    match encoder.encode_to_string(&metric_families) {
-        Ok(string) => (StatusCode::OK, string).into_response(),
-        Err(error) => (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            format!("failed to encode metrics {:?}", error),
-        )
-            .into_response(),
-    }
+    let encoded_metrics = encoder
+        .encode_to_string(&metric_families)
+        .map_err(|err| anyhow!("{}", err))?;
+    Ok(encoded_metrics)
 }
 
 fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {

crates/collab/src/rpc/store.rs 🔗

@@ -1,111 +1,32 @@
-use crate::db::{self, ProjectId, UserId};
+use crate::db::{self, UserId};
 use anyhow::{anyhow, Result};
-use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
+use collections::{BTreeMap, HashSet};
 use rpc::{proto, ConnectionId};
 use serde::Serialize;
-use std::path::PathBuf;
 use tracing::instrument;
 
-pub type RoomId = u64;
-
 #[derive(Default, Serialize)]
 pub struct Store {
-    connections: BTreeMap<ConnectionId, ConnectionState>,
+    connections: BTreeMap<ConnectionId, Connection>,
     connected_users: BTreeMap<UserId, ConnectedUser>,
-    next_room_id: RoomId,
-    rooms: BTreeMap<RoomId, proto::Room>,
-    projects: BTreeMap<ProjectId, Project>,
 }
 
 #[derive(Default, Serialize)]
 struct ConnectedUser {
     connection_ids: HashSet<ConnectionId>,
-    active_call: Option<Call>,
 }
 
 #[derive(Serialize)]
-struct ConnectionState {
-    user_id: UserId,
-    admin: bool,
-    projects: BTreeSet<ProjectId>,
-}
-
-#[derive(Copy, Clone, Eq, PartialEq, Serialize)]
-pub struct Call {
-    pub calling_user_id: UserId,
-    pub room_id: RoomId,
-    pub connection_id: Option<ConnectionId>,
-    pub initial_project_id: Option<ProjectId>,
-}
-
-#[derive(Serialize)]
-pub struct Project {
-    pub id: ProjectId,
-    pub room_id: RoomId,
-    pub host_connection_id: ConnectionId,
-    pub host: Collaborator,
-    pub guests: HashMap<ConnectionId, Collaborator>,
-    pub active_replica_ids: HashSet<ReplicaId>,
-    pub worktrees: BTreeMap<u64, Worktree>,
-    pub language_servers: Vec<proto::LanguageServer>,
-}
-
-#[derive(Serialize)]
-pub struct Collaborator {
-    pub replica_id: ReplicaId,
+pub struct Connection {
     pub user_id: UserId,
     pub admin: bool,
 }
 
-#[derive(Default, Serialize)]
-pub struct Worktree {
-    pub abs_path: PathBuf,
-    pub root_name: String,
-    pub visible: bool,
-    #[serde(skip)]
-    pub entries: BTreeMap<u64, proto::Entry>,
-    #[serde(skip)]
-    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
-    pub scan_id: u64,
-    pub is_complete: bool,
-}
-
-pub type ReplicaId = u16;
-
-#[derive(Copy, Clone)]
-pub struct Metrics {
-    pub connections: usize,
-    pub shared_projects: usize,
-}
-
 impl Store {
-    pub fn metrics(&self) -> Metrics {
-        let connections = self.connections.values().filter(|c| !c.admin).count();
-        let mut shared_projects = 0;
-        for project in self.projects.values() {
-            if let Some(connection) = self.connections.get(&project.host_connection_id) {
-                if !connection.admin {
-                    shared_projects += 1;
-                }
-            }
-        }
-
-        Metrics {
-            connections,
-            shared_projects,
-        }
-    }
-
     #[instrument(skip(self))]
     pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
-        self.connections.insert(
-            connection_id,
-            ConnectionState {
-                user_id,
-                admin,
-                projects: Default::default(),
-            },
-        );
+        self.connections
+            .insert(connection_id, Connection { user_id, admin });
         let connected_user = self.connected_users.entry(user_id).or_default();
         connected_user.connection_ids.insert(connection_id);
     }
@@ -127,10 +48,11 @@ impl Store {
         Ok(())
     }
 
-    pub fn connection_ids_for_user(
-        &self,
-        user_id: UserId,
-    ) -> impl Iterator<Item = ConnectionId> + '_ {
+    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
+        self.connections.values()
+    }
+
+    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
         self.connected_users
             .get(&user_id)
             .into_iter()
@@ -197,35 +119,9 @@ impl Store {
         }
     }
 
-    pub fn rooms(&self) -> &BTreeMap<RoomId, proto::Room> {
-        &self.rooms
-    }
-
     #[cfg(test)]
     pub fn check_invariants(&self) {
         for (connection_id, connection) in &self.connections {
-            for project_id in &connection.projects {
-                let project = &self.projects.get(project_id).unwrap();
-                if project.host_connection_id != *connection_id {
-                    assert!(project.guests.contains_key(connection_id));
-                }
-
-                for (worktree_id, worktree) in project.worktrees.iter() {
-                    let mut paths = HashMap::default();
-                    for entry in worktree.entries.values() {
-                        let prev_entry = paths.insert(&entry.path, entry);
-                        assert_eq!(
-                            prev_entry,
-                            None,
-                            "worktree {:?}, duplicate path for entries {:?} and {:?}",
-                            worktree_id,
-                            prev_entry.unwrap(),
-                            entry
-                        );
-                    }
-                }
-            }
-
             assert!(self
                 .connected_users
                 .get(&connection.user_id)
@@ -241,85 +137,6 @@ impl Store {
                     *user_id
                 );
             }
-
-            if let Some(active_call) = state.active_call.as_ref() {
-                if let Some(active_call_connection_id) = active_call.connection_id {
-                    assert!(
-                        state.connection_ids.contains(&active_call_connection_id),
-                        "call is active on a dead connection"
-                    );
-                    assert!(
-                        state.connection_ids.contains(&active_call_connection_id),
-                        "call is active on a dead connection"
-                    );
-                }
-            }
-        }
-
-        for (room_id, room) in &self.rooms {
-            // for pending_user_id in &room.pending_participant_user_ids {
-            //     assert!(
-            //         self.connected_users
-            //             .contains_key(&UserId::from_proto(*pending_user_id)),
-            //         "call is active on a user that has disconnected"
-            //     );
-            // }
-
-            for participant in &room.participants {
-                assert!(
-                    self.connections
-                        .contains_key(&ConnectionId(participant.peer_id)),
-                    "room {} contains participant {:?} that has disconnected",
-                    room_id,
-                    participant
-                );
-
-                for participant_project in &participant.projects {
-                    let project = &self.projects[&ProjectId::from_proto(participant_project.id)];
-                    assert_eq!(
-                        project.room_id, *room_id,
-                        "project was shared on a different room"
-                    );
-                }
-            }
-
-            // assert!(
-            //     !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(),
-            //     "room can't be empty"
-            // );
-        }
-
-        for (project_id, project) in &self.projects {
-            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
-            assert!(host_connection.projects.contains(project_id));
-
-            for guest_connection_id in project.guests.keys() {
-                let guest_connection = self.connections.get(guest_connection_id).unwrap();
-                assert!(guest_connection.projects.contains(project_id));
-            }
-            assert_eq!(project.active_replica_ids.len(), project.guests.len());
-            assert_eq!(
-                project.active_replica_ids,
-                project
-                    .guests
-                    .values()
-                    .map(|guest| guest.replica_id)
-                    .collect::<HashSet<_>>(),
-            );
-
-            let room = &self.rooms[&project.room_id];
-            let room_participant = room
-                .participants
-                .iter()
-                .find(|participant| participant.peer_id == project.host_connection_id.0)
-                .unwrap();
-            assert!(
-                room_participant
-                    .projects
-                    .iter()
-                    .any(|project| project.id == project_id.to_proto()),
-                "project was not shared in room"
-            );
         }
     }
 }