1use crate::db::UserId;
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashSet};
4use rpc::ConnectionId;
5use serde::Serialize;
6use tracing::instrument;
7
8#[derive(Default, Serialize)]
9pub struct ConnectionPool {
10 connections: BTreeMap<ConnectionId, Connection>,
11 connected_users: BTreeMap<UserId, ConnectedUser>,
12}
13
14#[derive(Default, Serialize)]
15struct ConnectedUser {
16 connection_ids: HashSet<ConnectionId>,
17}
18
19#[derive(Serialize)]
20pub struct Connection {
21 pub user_id: UserId,
22 pub admin: bool,
23}
24
25impl ConnectionPool {
26 pub fn reset(&mut self) {
27 self.connections.clear();
28 self.connected_users.clear();
29 }
30
31 #[instrument(skip(self))]
32 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
33 self.connections
34 .insert(connection_id, Connection { user_id, admin });
35 let connected_user = self.connected_users.entry(user_id).or_default();
36 connected_user.connection_ids.insert(connection_id);
37 }
38
39 #[instrument(skip(self))]
40 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
41 let connection = self
42 .connections
43 .get_mut(&connection_id)
44 .ok_or_else(|| anyhow!("no such connection"))?;
45
46 let user_id = connection.user_id;
47 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
48 connected_user.connection_ids.remove(&connection_id);
49 if connected_user.connection_ids.is_empty() {
50 self.connected_users.remove(&user_id);
51 }
52 self.connections.remove(&connection_id).unwrap();
53 Ok(())
54 }
55
56 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
57 self.connections.values()
58 }
59
60 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
61 self.connected_users
62 .get(&user_id)
63 .into_iter()
64 .map(|state| &state.connection_ids)
65 .flatten()
66 .copied()
67 }
68
69 pub fn is_user_online(&self, user_id: UserId) -> bool {
70 !self
71 .connected_users
72 .get(&user_id)
73 .unwrap_or(&Default::default())
74 .connection_ids
75 .is_empty()
76 }
77
78 #[cfg(test)]
79 pub fn check_invariants(&self) {
80 for (connection_id, connection) in &self.connections {
81 assert!(self
82 .connected_users
83 .get(&connection.user_id)
84 .unwrap()
85 .connection_ids
86 .contains(connection_id));
87 }
88
89 for (user_id, state) in &self.connected_users {
90 for connection_id in &state.connection_ids {
91 assert_eq!(
92 self.connections.get(connection_id).unwrap().user_id,
93 *user_id
94 );
95 }
96 }
97 }
98}