1use crate::db::{ChannelId, ChannelRole, UserId};
2use anyhow::{Context as _, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::ConnectionId;
5use semver::Version;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
/// In-memory registry of the currently registered connections, the users they
/// belong to, and the channel subscriptions of those users.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    // Every registered connection, keyed by its id.
    connections: BTreeMap<ConnectionId, Connection>,
    // For each user with at least one registered connection, the ids of those
    // connections (entries are removed once the set becomes empty).
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    // Channel subscriptions and the role each subscribed user holds.
    channels: ChannelPool,
}
16
/// Per-user bookkeeping inside [`ConnectionPool`]: the set of connection ids
/// currently registered for that user.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    connection_ids: HashSet<ConnectionId>,
}
21
/// A Zed version number, wrapping a semver [`Version`].
///
/// Equality and ordering are derived, so they compare the wrapped version.
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
pub struct ZedVersion(pub Version);
24
25impl fmt::Display for ZedVersion {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl ZedVersion {
32 pub fn can_collaborate(&self) -> bool {
33 // v0.220.0 was the version in which we've updated the project search protocol.
34 self.0 >= Version::new(0, 220, 0)
35 }
36}
37
/// Metadata stored for a single connection when it is added to a
/// [`ConnectionPool`].
#[derive(Serialize)]
pub struct Connection {
    // The user this connection belongs to.
    pub user_id: UserId,
    // Admin flag supplied by the caller of `ConnectionPool::add_connection`.
    pub admin: bool,
    // Client version supplied by the caller of `ConnectionPool::add_connection`.
    pub zed_version: ZedVersion,
}
44
45impl ConnectionPool {
46 pub fn reset(&mut self) {
47 self.connections.clear();
48 self.connected_users.clear();
49 self.channels.clear();
50 }
51
52 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
53 self.connections.get(&connection_id)
54 }
55
56 #[instrument(skip(self))]
57 pub fn add_connection(
58 &mut self,
59 connection_id: ConnectionId,
60 user_id: UserId,
61 admin: bool,
62 zed_version: ZedVersion,
63 ) {
64 self.connections.insert(
65 connection_id,
66 Connection {
67 user_id,
68 admin,
69 zed_version,
70 },
71 );
72 let connected_user = self.connected_users.entry(user_id).or_default();
73 connected_user.connection_ids.insert(connection_id);
74 }
75
76 #[instrument(skip(self))]
77 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
78 let connection = self
79 .connections
80 .get_mut(&connection_id)
81 .context("no such connection")?;
82
83 let user_id = connection.user_id;
84
85 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
86 connected_user.connection_ids.remove(&connection_id);
87 if connected_user.connection_ids.is_empty() {
88 self.connected_users.remove(&user_id);
89 self.channels.remove_user(&user_id);
90 };
91 self.connections.remove(&connection_id).unwrap();
92 Ok(())
93 }
94
95 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
96 self.connections.values()
97 }
98
99 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
100 self.connected_users
101 .get(&user_id)
102 .into_iter()
103 .flat_map(|state| {
104 state
105 .connection_ids
106 .iter()
107 .flat_map(|cid| self.connections.get(cid))
108 })
109 }
110
111 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
112 self.connected_users
113 .get(&user_id)
114 .into_iter()
115 .flat_map(|state| &state.connection_ids)
116 .copied()
117 }
118
119 pub fn channel_user_ids(
120 &self,
121 channel_id: ChannelId,
122 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
123 self.channels.users_to_notify(channel_id)
124 }
125
126 pub fn channel_connection_ids(
127 &self,
128 channel_id: ChannelId,
129 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
130 self.channels
131 .users_to_notify(channel_id)
132 .flat_map(|(user_id, role)| {
133 self.user_connection_ids(user_id)
134 .map(move |connection_id| (connection_id, role))
135 })
136 }
137
138 pub fn subscribe_to_channel(
139 &mut self,
140 user_id: UserId,
141 channel_id: ChannelId,
142 role: ChannelRole,
143 ) {
144 self.channels.subscribe(user_id, channel_id, role);
145 }
146
147 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
148 self.channels.unsubscribe(user_id, channel_id);
149 }
150
151 pub fn is_user_online(&self, user_id: UserId) -> bool {
152 !self
153 .connected_users
154 .get(&user_id)
155 .unwrap_or(&Default::default())
156 .connection_ids
157 .is_empty()
158 }
159
160 #[cfg(test)]
161 pub fn check_invariants(&self) {
162 for (connection_id, connection) in &self.connections {
163 assert!(
164 self.connected_users
165 .get(&connection.user_id)
166 .unwrap()
167 .connection_ids
168 .contains(connection_id)
169 );
170 }
171
172 for (user_id, state) in &self.connected_users {
173 for connection_id in &state.connection_ids {
174 assert_eq!(
175 self.connections.get(connection_id).unwrap().user_id,
176 *user_id
177 );
178 }
179 }
180 }
181}
182
/// Two-way index of channel subscriptions.
///
/// The `ChannelPool::subscribe` and `ChannelPool::unsubscribe` methods update both
/// maps together.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    // For each user, the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    // For each channel, the set of subscribed users.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
188
189impl ChannelPool {
190 pub fn clear(&mut self) {
191 self.by_user.clear();
192 self.by_channel.clear();
193 }
194
195 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
196 self.by_user
197 .entry(user_id)
198 .or_default()
199 .insert(channel_id, role);
200 self.by_channel
201 .entry(channel_id)
202 .or_default()
203 .insert(user_id);
204 }
205
206 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
207 if let Some(channels) = self.by_user.get_mut(user_id) {
208 channels.remove(channel_id);
209 if channels.is_empty() {
210 self.by_user.remove(user_id);
211 }
212 }
213 if let Some(users) = self.by_channel.get_mut(channel_id) {
214 users.remove(user_id);
215 if users.is_empty() {
216 self.by_channel.remove(channel_id);
217 }
218 }
219 }
220
221 pub fn remove_user(&mut self, user_id: &UserId) {
222 if let Some(channels) = self.by_user.remove(user_id) {
223 for channel_id in channels.keys() {
224 self.unsubscribe(user_id, channel_id)
225 }
226 }
227 }
228
229 pub fn users_to_notify(
230 &self,
231 channel_id: ChannelId,
232 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
233 self.by_channel
234 .get(&channel_id)
235 .into_iter()
236 .flat_map(move |users| {
237 users.iter().flat_map(move |user_id| {
238 Some((
239 *user_id,
240 self.by_user
241 .get(user_id)
242 .and_then(|channels| channels.get(&channel_id))
243 .copied()?,
244 ))
245 })
246 })
247 }
248}