connection_pool.rs

  1use crate::db::UserId;
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashSet};
  4use rpc::ConnectionId;
  5use serde::Serialize;
  6use tracing::instrument;
  7use util::SemanticVersion;
  8
  9#[derive(Default, Serialize)]
 10pub struct ConnectionPool {
 11    connections: BTreeMap<ConnectionId, Connection>,
 12    connected_users: BTreeMap<UserId, ConnectedUser>,
 13}
 14
 15#[derive(Default, Serialize)]
 16struct ConnectedUser {
 17    connection_ids: HashSet<ConnectionId>,
 18}
 19
 20#[derive(Debug, Serialize)]
 21pub struct ZedVersion(pub SemanticVersion);
 22use std::fmt;
 23
 24impl fmt::Display for ZedVersion {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        write!(f, "{}", self.0)
 27    }
 28}
 29
 30impl ZedVersion {
 31    pub fn is_supported(&self) -> bool {
 32        self.0 != SemanticVersion::new(0, 123, 0)
 33    }
 34    pub fn supports_talker_role(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 125, 0)
 36    }
 37}
 38
 39#[derive(Serialize)]
 40pub struct Connection {
 41    pub user_id: UserId,
 42    pub admin: bool,
 43    pub zed_version: ZedVersion,
 44}
 45
 46impl ConnectionPool {
 47    pub fn reset(&mut self) {
 48        self.connections.clear();
 49        self.connected_users.clear();
 50    }
 51
 52    #[instrument(skip(self))]
 53    pub fn add_connection(
 54        &mut self,
 55        connection_id: ConnectionId,
 56        user_id: UserId,
 57        admin: bool,
 58        zed_version: ZedVersion,
 59    ) {
 60        self.connections.insert(
 61            connection_id,
 62            Connection {
 63                user_id,
 64                admin,
 65                zed_version,
 66            },
 67        );
 68        let connected_user = self.connected_users.entry(user_id).or_default();
 69        connected_user.connection_ids.insert(connection_id);
 70    }
 71
 72    #[instrument(skip(self))]
 73    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 74        let connection = self
 75            .connections
 76            .get_mut(&connection_id)
 77            .ok_or_else(|| anyhow!("no such connection"))?;
 78
 79        let user_id = connection.user_id;
 80        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 81        connected_user.connection_ids.remove(&connection_id);
 82        if connected_user.connection_ids.is_empty() {
 83            self.connected_users.remove(&user_id);
 84        }
 85        self.connections.remove(&connection_id).unwrap();
 86        Ok(())
 87    }
 88
 89    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 90        self.connections.values()
 91    }
 92
 93    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
 94        self.connected_users
 95            .get(&user_id)
 96            .into_iter()
 97            .flat_map(|state| {
 98                state
 99                    .connection_ids
100                    .iter()
101                    .flat_map(|cid| self.connections.get(cid))
102            })
103    }
104
105    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
106        self.connected_users
107            .get(&user_id)
108            .into_iter()
109            .flat_map(|state| &state.connection_ids)
110            .copied()
111    }
112
113    pub fn is_user_online(&self, user_id: UserId) -> bool {
114        !self
115            .connected_users
116            .get(&user_id)
117            .unwrap_or(&Default::default())
118            .connection_ids
119            .is_empty()
120    }
121
122    #[cfg(test)]
123    pub fn check_invariants(&self) {
124        for (connection_id, connection) in &self.connections {
125            assert!(self
126                .connected_users
127                .get(&connection.user_id)
128                .unwrap()
129                .connection_ids
130                .contains(connection_id));
131        }
132
133        for (user_id, state) in &self.connected_users {
134            for connection_id in &state.connection_ids {
135                assert_eq!(
136                    self.connections.get(connection_id).unwrap().user_id,
137                    *user_id
138                );
139            }
140        }
141    }
142}