connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, UserId};
  2use anyhow::{Context as _, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::ConnectionId;
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    channels: ChannelPool,
 15}
 16
 17#[derive(Default, Serialize)]
 18struct ConnectedPrincipal {
 19    connection_ids: HashSet<ConnectionId>,
 20}
 21
 22#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 23pub struct ZedVersion(pub SemanticVersion);
 24
 25impl fmt::Display for ZedVersion {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31impl ZedVersion {
 32    pub fn can_collaborate(&self) -> bool {
 33        // v0.198.4 is the first version where we no longer connect to Collab automatically.
 34        // We reject any clients older than that to prevent them from connecting to Collab just for authentication.
 35        if self.0 < SemanticVersion::new(0, 198, 4) {
 36            return false;
 37        }
 38
 39        // Since we hotfixed the changes to no longer connect to Collab automatically to Preview, we also need to reject
 40        // versions in the range [v0.199.0, v0.199.1].
 41        if self.0 >= SemanticVersion::new(0, 199, 0) && self.0 < SemanticVersion::new(0, 199, 2) {
 42            return false;
 43        }
 44
 45        true
 46    }
 47}
 48
 49#[derive(Serialize)]
 50pub struct Connection {
 51    pub user_id: UserId,
 52    pub admin: bool,
 53    pub zed_version: ZedVersion,
 54}
 55
 56impl ConnectionPool {
 57    pub fn reset(&mut self) {
 58        self.connections.clear();
 59        self.connected_users.clear();
 60        self.channels.clear();
 61    }
 62
 63    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 64        self.connections.get(&connection_id)
 65    }
 66
 67    #[instrument(skip(self))]
 68    pub fn add_connection(
 69        &mut self,
 70        connection_id: ConnectionId,
 71        user_id: UserId,
 72        admin: bool,
 73        zed_version: ZedVersion,
 74    ) {
 75        self.connections.insert(
 76            connection_id,
 77            Connection {
 78                user_id,
 79                admin,
 80                zed_version,
 81            },
 82        );
 83        let connected_user = self.connected_users.entry(user_id).or_default();
 84        connected_user.connection_ids.insert(connection_id);
 85    }
 86
 87    #[instrument(skip(self))]
 88    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 89        let connection = self
 90            .connections
 91            .get_mut(&connection_id)
 92            .context("no such connection")?;
 93
 94        let user_id = connection.user_id;
 95
 96        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 97        connected_user.connection_ids.remove(&connection_id);
 98        if connected_user.connection_ids.is_empty() {
 99            self.connected_users.remove(&user_id);
100            self.channels.remove_user(&user_id);
101        };
102        self.connections.remove(&connection_id).unwrap();
103        Ok(())
104    }
105
106    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
107        self.connections.values()
108    }
109
110    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
111        self.connected_users
112            .get(&user_id)
113            .into_iter()
114            .flat_map(|state| {
115                state
116                    .connection_ids
117                    .iter()
118                    .flat_map(|cid| self.connections.get(cid))
119            })
120    }
121
122    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
123        self.connected_users
124            .get(&user_id)
125            .into_iter()
126            .flat_map(|state| &state.connection_ids)
127            .copied()
128    }
129
130    pub fn channel_user_ids(
131        &self,
132        channel_id: ChannelId,
133    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
134        self.channels.users_to_notify(channel_id)
135    }
136
137    pub fn channel_connection_ids(
138        &self,
139        channel_id: ChannelId,
140    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
141        self.channels
142            .users_to_notify(channel_id)
143            .flat_map(|(user_id, role)| {
144                self.user_connection_ids(user_id)
145                    .map(move |connection_id| (connection_id, role))
146            })
147    }
148
149    pub fn subscribe_to_channel(
150        &mut self,
151        user_id: UserId,
152        channel_id: ChannelId,
153        role: ChannelRole,
154    ) {
155        self.channels.subscribe(user_id, channel_id, role);
156    }
157
158    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
159        self.channels.unsubscribe(user_id, channel_id);
160    }
161
162    pub fn is_user_online(&self, user_id: UserId) -> bool {
163        !self
164            .connected_users
165            .get(&user_id)
166            .unwrap_or(&Default::default())
167            .connection_ids
168            .is_empty()
169    }
170
171    #[cfg(test)]
172    pub fn check_invariants(&self) {
173        for (connection_id, connection) in &self.connections {
174            assert!(
175                self.connected_users
176                    .get(&connection.user_id)
177                    .unwrap()
178                    .connection_ids
179                    .contains(connection_id)
180            );
181        }
182
183        for (user_id, state) in &self.connected_users {
184            for connection_id in &state.connection_ids {
185                assert_eq!(
186                    self.connections.get(connection_id).unwrap().user_id,
187                    *user_id
188                );
189            }
190        }
191    }
192}
193
194#[derive(Default, Serialize)]
195pub struct ChannelPool {
196    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
197    by_channel: HashMap<ChannelId, HashSet<UserId>>,
198}
199
200impl ChannelPool {
201    pub fn clear(&mut self) {
202        self.by_user.clear();
203        self.by_channel.clear();
204    }
205
206    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
207        self.by_user
208            .entry(user_id)
209            .or_default()
210            .insert(channel_id, role);
211        self.by_channel
212            .entry(channel_id)
213            .or_default()
214            .insert(user_id);
215    }
216
217    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
218        if let Some(channels) = self.by_user.get_mut(user_id) {
219            channels.remove(channel_id);
220            if channels.is_empty() {
221                self.by_user.remove(user_id);
222            }
223        }
224        if let Some(users) = self.by_channel.get_mut(channel_id) {
225            users.remove(user_id);
226            if users.is_empty() {
227                self.by_channel.remove(channel_id);
228            }
229        }
230    }
231
232    pub fn remove_user(&mut self, user_id: &UserId) {
233        if let Some(channels) = self.by_user.remove(user_id) {
234            for channel_id in channels.keys() {
235                self.unsubscribe(user_id, channel_id)
236            }
237        }
238    }
239
240    pub fn users_to_notify(
241        &self,
242        channel_id: ChannelId,
243    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
244        self.by_channel
245            .get(&channel_id)
246            .into_iter()
247            .flat_map(move |users| {
248                users.iter().flat_map(move |user_id| {
249                    Some((
250                        *user_id,
251                        self.by_user
252                            .get(user_id)
253                            .and_then(|channels| channels.get(&channel_id))
254                            .copied()?,
255                    ))
256                })
257            })
258    }
259}