connection_pool.rs

  1use crate::db::UserId;
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashSet};
  4use rpc::ConnectionId;
  5use serde::Serialize;
  6use tracing::instrument;
  7use util::SemanticVersion;
  8
  9#[derive(Default, Serialize)]
 10pub struct ConnectionPool {
 11    connections: BTreeMap<ConnectionId, Connection>,
 12    connected_users: BTreeMap<UserId, ConnectedUser>,
 13}
 14
 15#[derive(Default, Serialize)]
 16struct ConnectedUser {
 17    connection_ids: HashSet<ConnectionId>,
 18}
 19
 20#[derive(Debug, Serialize)]
 21pub struct ZedVersion(pub SemanticVersion);
 22use std::fmt;
 23
 24impl fmt::Display for ZedVersion {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        write!(f, "{}", self.0)
 27    }
 28}
 29
 30impl ZedVersion {
 31    pub fn is_supported(&self) -> bool {
 32        self.0 != SemanticVersion::new(0, 123, 0)
 33    }
 34    pub fn supports_talker_role(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 125, 0)
 36    }
 37}
 38
 39#[derive(Serialize)]
 40pub struct Connection {
 41    pub user_id: UserId,
 42    pub admin: bool,
 43    pub zed_version: ZedVersion,
 44}
 45
 46impl ConnectionPool {
 47    pub fn reset(&mut self) {
 48        self.connections.clear();
 49        self.connected_users.clear();
 50    }
 51
 52    #[instrument(skip(self))]
 53    pub fn add_connection(
 54        &mut self,
 55        connection_id: ConnectionId,
 56        user_id: UserId,
 57        admin: bool,
 58        zed_version: ZedVersion,
 59    ) {
 60        self.connections.insert(
 61            connection_id,
 62            Connection {
 63                user_id,
 64                admin,
 65                zed_version,
 66            },
 67        );
 68        let connected_user = self.connected_users.entry(user_id).or_default();
 69        connected_user.connection_ids.insert(connection_id);
 70    }
 71
 72    #[instrument(skip(self))]
 73    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 74        let connection = self
 75            .connections
 76            .get_mut(&connection_id)
 77            .ok_or_else(|| anyhow!("no such connection"))?;
 78
 79        let user_id = connection.user_id;
 80        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 81        connected_user.connection_ids.remove(&connection_id);
 82        if connected_user.connection_ids.is_empty() {
 83            self.connected_users.remove(&user_id);
 84        }
 85        self.connections.remove(&connection_id).unwrap();
 86        Ok(())
 87    }
 88
 89    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 90        self.connections.values()
 91    }
 92
 93    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
 94        self.connected_users
 95            .get(&user_id)
 96            .into_iter()
 97            .map(|state| {
 98                state
 99                    .connection_ids
100                    .iter()
101                    .flat_map(|cid| self.connections.get(cid))
102            })
103            .flatten()
104    }
105
106    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
107        self.connected_users
108            .get(&user_id)
109            .into_iter()
110            .map(|state| &state.connection_ids)
111            .flatten()
112            .copied()
113    }
114
115    pub fn is_user_online(&self, user_id: UserId) -> bool {
116        !self
117            .connected_users
118            .get(&user_id)
119            .unwrap_or(&Default::default())
120            .connection_ids
121            .is_empty()
122    }
123
124    #[cfg(test)]
125    pub fn check_invariants(&self) {
126        for (connection_id, connection) in &self.connections {
127            assert!(self
128                .connected_users
129                .get(&connection.user_id)
130                .unwrap()
131                .connection_ids
132                .contains(connection_id));
133        }
134
135        for (user_id, state) in &self.connected_users {
136            for connection_id in &state.connection_ids {
137                assert_eq!(
138                    self.connections.get(connection_id).unwrap().user_id,
139                    *user_id
140                );
141            }
142        }
143    }
144}