connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
 15    channels: ChannelPool,
 16    offline_dev_servers: HashSet<DevServerId>,
 17}
 18
 19#[derive(Default, Serialize)]
 20struct ConnectedPrincipal {
 21    connection_ids: HashSet<ConnectionId>,
 22}
 23
 24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 25pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 151, 0)
 36    }
 37}
 38
 39#[derive(Serialize)]
 40pub struct Connection {
 41    pub principal_id: PrincipalId,
 42    pub admin: bool,
 43    pub zed_version: ZedVersion,
 44}
 45
 46impl ConnectionPool {
 47    pub fn reset(&mut self) {
 48        self.connections.clear();
 49        self.connected_users.clear();
 50        self.connected_dev_servers.clear();
 51        self.channels.clear();
 52    }
 53
 54    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 55        self.connections.get(&connection_id)
 56    }
 57
 58    #[instrument(skip(self))]
 59    pub fn add_connection(
 60        &mut self,
 61        connection_id: ConnectionId,
 62        user_id: UserId,
 63        admin: bool,
 64        zed_version: ZedVersion,
 65    ) {
 66        self.connections.insert(
 67            connection_id,
 68            Connection {
 69                principal_id: PrincipalId::UserId(user_id),
 70                admin,
 71                zed_version,
 72            },
 73        );
 74        let connected_user = self.connected_users.entry(user_id).or_default();
 75        connected_user.connection_ids.insert(connection_id);
 76    }
 77
 78    pub fn add_dev_server(
 79        &mut self,
 80        connection_id: ConnectionId,
 81        dev_server_id: DevServerId,
 82        zed_version: ZedVersion,
 83    ) {
 84        self.connections.insert(
 85            connection_id,
 86            Connection {
 87                principal_id: PrincipalId::DevServerId(dev_server_id),
 88                admin: false,
 89                zed_version,
 90            },
 91        );
 92
 93        self.connected_dev_servers
 94            .insert(dev_server_id, connection_id);
 95    }
 96
 97    #[instrument(skip(self))]
 98    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 99        let connection = self
100            .connections
101            .get_mut(&connection_id)
102            .ok_or_else(|| anyhow!("no such connection"))?;
103
104        match connection.principal_id {
105            PrincipalId::UserId(user_id) => {
106                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
107                connected_user.connection_ids.remove(&connection_id);
108                if connected_user.connection_ids.is_empty() {
109                    self.connected_users.remove(&user_id);
110                    self.channels.remove_user(&user_id);
111                }
112            }
113            PrincipalId::DevServerId(dev_server_id) => {
114                self.connected_dev_servers.remove(&dev_server_id);
115                self.offline_dev_servers.remove(&dev_server_id);
116            }
117        }
118        self.connections.remove(&connection_id).unwrap();
119        Ok(())
120    }
121
122    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
123        self.offline_dev_servers.insert(dev_server_id);
124    }
125
126    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
127        self.connections.values()
128    }
129
130    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
131        self.connected_users
132            .get(&user_id)
133            .into_iter()
134            .flat_map(|state| {
135                state
136                    .connection_ids
137                    .iter()
138                    .flat_map(|cid| self.connections.get(cid))
139            })
140    }
141
142    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
143        self.connected_users
144            .get(&user_id)
145            .into_iter()
146            .flat_map(|state| &state.connection_ids)
147            .copied()
148    }
149
150    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
151        if self.dev_server_connection_id(dev_server_id).is_some()
152            && !self.offline_dev_servers.contains(&dev_server_id)
153        {
154            proto::DevServerStatus::Online
155        } else {
156            proto::DevServerStatus::Offline
157        }
158    }
159
160    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
161        self.connected_dev_servers.get(&dev_server_id).copied()
162    }
163
164    pub fn online_dev_server_connection_id(
165        &self,
166        dev_server_id: DevServerId,
167    ) -> Result<ConnectionId> {
168        match self.connected_dev_servers.get(&dev_server_id) {
169            Some(cid) => Ok(*cid),
170            None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
171        }
172    }
173
174    pub fn dev_server_connection_id_supporting(
175        &self,
176        dev_server_id: DevServerId,
177        required: ZedVersion,
178    ) -> Result<ConnectionId> {
179        match self.connected_dev_servers.get(&dev_server_id) {
180            Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
181            Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
182            None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
183        }
184    }
185
186    pub fn channel_user_ids(
187        &self,
188        channel_id: ChannelId,
189    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
190        self.channels.users_to_notify(channel_id)
191    }
192
193    pub fn channel_connection_ids(
194        &self,
195        channel_id: ChannelId,
196    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
197        self.channels
198            .users_to_notify(channel_id)
199            .flat_map(|(user_id, role)| {
200                self.user_connection_ids(user_id)
201                    .map(move |connection_id| (connection_id, role))
202            })
203    }
204
205    pub fn subscribe_to_channel(
206        &mut self,
207        user_id: UserId,
208        channel_id: ChannelId,
209        role: ChannelRole,
210    ) {
211        self.channels.subscribe(user_id, channel_id, role);
212    }
213
214    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
215        self.channels.unsubscribe(user_id, channel_id);
216    }
217
218    pub fn is_user_online(&self, user_id: UserId) -> bool {
219        !self
220            .connected_users
221            .get(&user_id)
222            .unwrap_or(&Default::default())
223            .connection_ids
224            .is_empty()
225    }
226
227    #[cfg(test)]
228    pub fn check_invariants(&self) {
229        for (connection_id, connection) in &self.connections {
230            match &connection.principal_id {
231                PrincipalId::UserId(user_id) => {
232                    assert!(self
233                        .connected_users
234                        .get(user_id)
235                        .unwrap()
236                        .connection_ids
237                        .contains(connection_id));
238                }
239                PrincipalId::DevServerId(dev_server_id) => {
240                    assert_eq!(
241                        self.connected_dev_servers.get(dev_server_id).unwrap(),
242                        connection_id
243                    );
244                }
245            }
246        }
247
248        for (user_id, state) in &self.connected_users {
249            for connection_id in &state.connection_ids {
250                assert_eq!(
251                    self.connections.get(connection_id).unwrap().principal_id,
252                    PrincipalId::UserId(*user_id)
253                );
254            }
255        }
256
257        for (dev_server_id, connection_id) in &self.connected_dev_servers {
258            assert_eq!(
259                self.connections.get(connection_id).unwrap().principal_id,
260                PrincipalId::DevServerId(*dev_server_id)
261            );
262        }
263    }
264}
265
266#[derive(Default, Serialize)]
267pub struct ChannelPool {
268    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
269    by_channel: HashMap<ChannelId, HashSet<UserId>>,
270}
271
272impl ChannelPool {
273    pub fn clear(&mut self) {
274        self.by_user.clear();
275        self.by_channel.clear();
276    }
277
278    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
279        self.by_user
280            .entry(user_id)
281            .or_default()
282            .insert(channel_id, role);
283        self.by_channel
284            .entry(channel_id)
285            .or_default()
286            .insert(user_id);
287    }
288
289    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
290        if let Some(channels) = self.by_user.get_mut(user_id) {
291            channels.remove(channel_id);
292            if channels.is_empty() {
293                self.by_user.remove(user_id);
294            }
295        }
296        if let Some(users) = self.by_channel.get_mut(channel_id) {
297            users.remove(user_id);
298            if users.is_empty() {
299                self.by_channel.remove(channel_id);
300            }
301        }
302    }
303
304    pub fn remove_user(&mut self, user_id: &UserId) {
305        if let Some(channels) = self.by_user.remove(user_id) {
306            for channel_id in channels.keys() {
307                self.unsubscribe(user_id, channel_id)
308            }
309        }
310    }
311
312    pub fn users_to_notify(
313        &self,
314        channel_id: ChannelId,
315    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
316        self.by_channel
317            .get(&channel_id)
318            .into_iter()
319            .flat_map(move |users| {
320                users.iter().flat_map(move |user_id| {
321                    Some((
322                        *user_id,
323                        self.by_user
324                            .get(user_id)
325                            .and_then(|channels| channels.get(&channel_id))
326                            .copied()?,
327                    ))
328                })
329            })
330    }
331}