connection_pool.rs

 1use crate::db::UserId;
 2use anyhow::{anyhow, Result};
 3use collections::{BTreeMap, HashSet};
 4use rpc::ConnectionId;
 5use serde::Serialize;
 6use tracing::instrument;
 7
 8#[derive(Default, Serialize)]
 9pub struct ConnectionPool {
10    connections: BTreeMap<ConnectionId, Connection>,
11    connected_users: BTreeMap<UserId, ConnectedUser>,
12}
13
14#[derive(Default, Serialize)]
15struct ConnectedUser {
16    connection_ids: HashSet<ConnectionId>,
17}
18
19#[derive(Serialize)]
20pub struct Connection {
21    pub user_id: UserId,
22    pub admin: bool,
23}
24
25impl ConnectionPool {
26    #[instrument(skip(self))]
27    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
28        self.connections
29            .insert(connection_id, Connection { user_id, admin });
30        let connected_user = self.connected_users.entry(user_id).or_default();
31        connected_user.connection_ids.insert(connection_id);
32    }
33
34    #[instrument(skip(self))]
35    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
36        let connection = self
37            .connections
38            .get_mut(&connection_id)
39            .ok_or_else(|| anyhow!("no such connection"))?;
40
41        let user_id = connection.user_id;
42        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
43        connected_user.connection_ids.remove(&connection_id);
44        if connected_user.connection_ids.is_empty() {
45            self.connected_users.remove(&user_id);
46        }
47        self.connections.remove(&connection_id).unwrap();
48        Ok(())
49    }
50
51    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
52        self.connections.values()
53    }
54
55    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
56        self.connected_users
57            .get(&user_id)
58            .into_iter()
59            .map(|state| &state.connection_ids)
60            .flatten()
61            .copied()
62    }
63
64    pub fn is_user_online(&self, user_id: UserId) -> bool {
65        !self
66            .connected_users
67            .get(&user_id)
68            .unwrap_or(&Default::default())
69            .connection_ids
70            .is_empty()
71    }
72
73    #[cfg(test)]
74    pub fn check_invariants(&self) {
75        for (connection_id, connection) in &self.connections {
76            assert!(self
77                .connected_users
78                .get(&connection.user_id)
79                .unwrap()
80                .connection_ids
81                .contains(connection_id));
82        }
83
84        for (user_id, state) in &self.connected_users {
85            for connection_id in &state.connection_ids {
86                assert_eq!(
87                    self.connections.get(connection_id).unwrap().user_id,
88                    *user_id
89                );
90            }
91        }
92    }
93}