connection_pool.rs

 1use crate::db::UserId;
 2use anyhow::{anyhow, Result};
 3use collections::{BTreeMap, HashSet};
 4use rpc::ConnectionId;
 5use serde::Serialize;
 6use tracing::instrument;
 7
 8#[derive(Default, Serialize)]
 9pub struct ConnectionPool {
10    connections: BTreeMap<ConnectionId, Connection>,
11    connected_users: BTreeMap<UserId, ConnectedUser>,
12}
13
14#[derive(Default, Serialize)]
15struct ConnectedUser {
16    connection_ids: HashSet<ConnectionId>,
17}
18
19#[derive(Serialize)]
20pub struct Connection {
21    pub user_id: UserId,
22    pub admin: bool,
23}
24
25impl ConnectionPool {
26    pub fn reset(&mut self) {
27        self.connections.clear();
28        self.connected_users.clear();
29    }
30
31    #[instrument(skip(self))]
32    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
33        self.connections
34            .insert(connection_id, Connection { user_id, admin });
35        let connected_user = self.connected_users.entry(user_id).or_default();
36        connected_user.connection_ids.insert(connection_id);
37    }
38
39    #[instrument(skip(self))]
40    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
41        let connection = self
42            .connections
43            .get_mut(&connection_id)
44            .ok_or_else(|| anyhow!("no such connection"))?;
45
46        let user_id = connection.user_id;
47        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
48        connected_user.connection_ids.remove(&connection_id);
49        if connected_user.connection_ids.is_empty() {
50            self.connected_users.remove(&user_id);
51        }
52        self.connections.remove(&connection_id).unwrap();
53        Ok(())
54    }
55
56    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
57        self.connections.values()
58    }
59
60    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
61        self.connected_users
62            .get(&user_id)
63            .into_iter()
64            .map(|state| &state.connection_ids)
65            .flatten()
66            .copied()
67    }
68
69    pub fn is_user_online(&self, user_id: UserId) -> bool {
70        !self
71            .connected_users
72            .get(&user_id)
73            .unwrap_or(&Default::default())
74            .connection_ids
75            .is_empty()
76    }
77
78    #[cfg(test)]
79    pub fn check_invariants(&self) {
80        for (connection_id, connection) in &self.connections {
81            assert!(self
82                .connected_users
83                .get(&connection.user_id)
84                .unwrap()
85                .connection_ids
86                .contains(connection_id));
87        }
88
89        for (user_id, state) in &self.connected_users {
90            for connection_id in &state.connection_ids {
91                assert_eq!(
92                    self.connections.get(connection_id).unwrap().user_id,
93                    *user_id
94                );
95            }
96        }
97    }
98}