1use crate::db::UserId;
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashSet};
4use rpc::ConnectionId;
5use serde::Serialize;
6use tracing::instrument;
7use util::SemanticVersion;
8
9#[derive(Default, Serialize)]
10pub struct ConnectionPool {
11 connections: BTreeMap<ConnectionId, Connection>,
12 connected_users: BTreeMap<UserId, ConnectedUser>,
13}
14
15#[derive(Default, Serialize)]
16struct ConnectedUser {
17 connection_ids: HashSet<ConnectionId>,
18}
19
20#[derive(Debug, Serialize)]
21pub struct ZedVersion(pub SemanticVersion);
22use std::fmt;
23
24impl fmt::Display for ZedVersion {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 write!(f, "{}", self.0)
27 }
28}
29
30impl ZedVersion {
31 pub fn is_supported(&self) -> bool {
32 self.0 != SemanticVersion::new(0, 123, 0)
33 }
34 pub fn supports_talker_role(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 125, 0)
36 }
37}
38
39#[derive(Serialize)]
40pub struct Connection {
41 pub user_id: UserId,
42 pub admin: bool,
43 pub zed_version: ZedVersion,
44}
45
46impl ConnectionPool {
47 pub fn reset(&mut self) {
48 self.connections.clear();
49 self.connected_users.clear();
50 }
51
52 #[instrument(skip(self))]
53 pub fn add_connection(
54 &mut self,
55 connection_id: ConnectionId,
56 user_id: UserId,
57 admin: bool,
58 zed_version: ZedVersion,
59 ) {
60 self.connections.insert(
61 connection_id,
62 Connection {
63 user_id,
64 admin,
65 zed_version,
66 },
67 );
68 let connected_user = self.connected_users.entry(user_id).or_default();
69 connected_user.connection_ids.insert(connection_id);
70 }
71
72 #[instrument(skip(self))]
73 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
74 let connection = self
75 .connections
76 .get_mut(&connection_id)
77 .ok_or_else(|| anyhow!("no such connection"))?;
78
79 let user_id = connection.user_id;
80 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
81 connected_user.connection_ids.remove(&connection_id);
82 if connected_user.connection_ids.is_empty() {
83 self.connected_users.remove(&user_id);
84 }
85 self.connections.remove(&connection_id).unwrap();
86 Ok(())
87 }
88
89 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
90 self.connections.values()
91 }
92
93 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
94 self.connected_users
95 .get(&user_id)
96 .into_iter()
97 .flat_map(|state| {
98 state
99 .connection_ids
100 .iter()
101 .flat_map(|cid| self.connections.get(cid))
102 })
103 }
104
105 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
106 self.connected_users
107 .get(&user_id)
108 .into_iter()
109 .flat_map(|state| &state.connection_ids)
110 .copied()
111 }
112
113 pub fn is_user_online(&self, user_id: UserId) -> bool {
114 !self
115 .connected_users
116 .get(&user_id)
117 .unwrap_or(&Default::default())
118 .connection_ids
119 .is_empty()
120 }
121
122 #[cfg(test)]
123 pub fn check_invariants(&self) {
124 for (connection_id, connection) in &self.connections {
125 assert!(self
126 .connected_users
127 .get(&connection.user_id)
128 .unwrap()
129 .connection_ids
130 .contains(connection_id));
131 }
132
133 for (user_id, state) in &self.connected_users {
134 for connection_id in &state.connection_ids {
135 assert_eq!(
136 self.connections.get(connection_id).unwrap().user_id,
137 *user_id
138 );
139 }
140 }
141 }
142}