connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, UserId};
  2use anyhow::{Context as _, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::ConnectionId;
  5use semver::Version;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    channels: ChannelPool,
 15}
 16
 17#[derive(Default, Serialize)]
 18struct ConnectedPrincipal {
 19    connection_ids: HashSet<ConnectionId>,
 20}
 21
 22#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 23pub struct ZedVersion(pub Version);
 24
 25impl fmt::Display for ZedVersion {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31impl ZedVersion {
 32    pub fn can_collaborate(&self) -> bool {
 33        // v0.220.0 was the version in which we've updated the project search protocol.
 34        self.0 >= Version::new(0, 220, 0)
 35    }
 36}
 37
 38#[derive(Serialize)]
 39pub struct Connection {
 40    pub user_id: UserId,
 41    pub admin: bool,
 42    pub zed_version: ZedVersion,
 43}
 44
 45impl ConnectionPool {
 46    pub fn reset(&mut self) {
 47        self.connections.clear();
 48        self.connected_users.clear();
 49        self.channels.clear();
 50    }
 51
 52    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 53        self.connections.get(&connection_id)
 54    }
 55
 56    #[instrument(skip(self))]
 57    pub fn add_connection(
 58        &mut self,
 59        connection_id: ConnectionId,
 60        user_id: UserId,
 61        admin: bool,
 62        zed_version: ZedVersion,
 63    ) {
 64        self.connections.insert(
 65            connection_id,
 66            Connection {
 67                user_id,
 68                admin,
 69                zed_version,
 70            },
 71        );
 72        let connected_user = self.connected_users.entry(user_id).or_default();
 73        connected_user.connection_ids.insert(connection_id);
 74    }
 75
 76    #[instrument(skip(self))]
 77    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 78        let connection = self
 79            .connections
 80            .get_mut(&connection_id)
 81            .context("no such connection")?;
 82
 83        let user_id = connection.user_id;
 84
 85        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 86        connected_user.connection_ids.remove(&connection_id);
 87        if connected_user.connection_ids.is_empty() {
 88            self.connected_users.remove(&user_id);
 89            self.channels.remove_user(&user_id);
 90        };
 91        self.connections.remove(&connection_id).unwrap();
 92        Ok(())
 93    }
 94
 95    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 96        self.connections.values()
 97    }
 98
 99    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
100        self.connected_users
101            .get(&user_id)
102            .into_iter()
103            .flat_map(|state| {
104                state
105                    .connection_ids
106                    .iter()
107                    .flat_map(|cid| self.connections.get(cid))
108            })
109    }
110
111    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
112        self.connected_users
113            .get(&user_id)
114            .into_iter()
115            .flat_map(|state| &state.connection_ids)
116            .copied()
117    }
118
119    pub fn channel_user_ids(
120        &self,
121        channel_id: ChannelId,
122    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
123        self.channels.users_to_notify(channel_id)
124    }
125
126    pub fn channel_connection_ids(
127        &self,
128        channel_id: ChannelId,
129    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
130        self.channels
131            .users_to_notify(channel_id)
132            .flat_map(|(user_id, role)| {
133                self.user_connection_ids(user_id)
134                    .map(move |connection_id| (connection_id, role))
135            })
136    }
137
138    pub fn subscribe_to_channel(
139        &mut self,
140        user_id: UserId,
141        channel_id: ChannelId,
142        role: ChannelRole,
143    ) {
144        self.channels.subscribe(user_id, channel_id, role);
145    }
146
147    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
148        self.channels.unsubscribe(user_id, channel_id);
149    }
150
151    pub fn is_user_online(&self, user_id: UserId) -> bool {
152        !self
153            .connected_users
154            .get(&user_id)
155            .unwrap_or(&Default::default())
156            .connection_ids
157            .is_empty()
158    }
159
160    #[cfg(test)]
161    pub fn check_invariants(&self) {
162        for (connection_id, connection) in &self.connections {
163            assert!(
164                self.connected_users
165                    .get(&connection.user_id)
166                    .unwrap()
167                    .connection_ids
168                    .contains(connection_id)
169            );
170        }
171
172        for (user_id, state) in &self.connected_users {
173            for connection_id in &state.connection_ids {
174                assert_eq!(
175                    self.connections.get(connection_id).unwrap().user_id,
176                    *user_id
177                );
178            }
179        }
180    }
181}
182
183#[derive(Default, Serialize)]
184pub struct ChannelPool {
185    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
186    by_channel: HashMap<ChannelId, HashSet<UserId>>,
187}
188
189impl ChannelPool {
190    pub fn clear(&mut self) {
191        self.by_user.clear();
192        self.by_channel.clear();
193    }
194
195    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
196        self.by_user
197            .entry(user_id)
198            .or_default()
199            .insert(channel_id, role);
200        self.by_channel
201            .entry(channel_id)
202            .or_default()
203            .insert(user_id);
204    }
205
206    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
207        if let Some(channels) = self.by_user.get_mut(user_id) {
208            channels.remove(channel_id);
209            if channels.is_empty() {
210                self.by_user.remove(user_id);
211            }
212        }
213        if let Some(users) = self.by_channel.get_mut(channel_id) {
214            users.remove(user_id);
215            if users.is_empty() {
216                self.by_channel.remove(channel_id);
217            }
218        }
219    }
220
221    pub fn remove_user(&mut self, user_id: &UserId) {
222        if let Some(channels) = self.by_user.remove(user_id) {
223            for channel_id in channels.keys() {
224                self.unsubscribe(user_id, channel_id)
225            }
226        }
227    }
228
229    pub fn users_to_notify(
230        &self,
231        channel_id: ChannelId,
232    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
233        self.by_channel
234            .get(&channel_id)
235            .into_iter()
236            .flat_map(move |users| {
237                users.iter().flat_map(move |user_id| {
238                    Some((
239                        *user_id,
240                        self.by_user
241                            .get(user_id)
242                            .and_then(|channels| channels.get(&channel_id))
243                            .copied()?,
244                    ))
245                })
246            })
247    }
248}