connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::ConnectionId;
  5use serde::Serialize;
  6use tracing::instrument;
  7use util::{semver, SemanticVersion};
  8
  9#[derive(Default, Serialize)]
 10pub struct ConnectionPool {
 11    connections: BTreeMap<ConnectionId, Connection>,
 12    connected_users: BTreeMap<UserId, ConnectedUser>,
 13    channels: ChannelPool,
 14}
 15
 16#[derive(Default, Serialize)]
 17struct ConnectedUser {
 18    connection_ids: HashSet<ConnectionId>,
 19}
 20
 21#[derive(Debug, Serialize)]
 22pub struct ZedVersion(pub SemanticVersion);
 23use std::fmt;
 24
 25impl fmt::Display for ZedVersion {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31impl ZedVersion {
 32    pub fn can_collaborate(&self) -> bool {
 33        self.0 >= semver(0, 127, 3)
 34    }
 35}
 36
 37#[derive(Serialize)]
 38pub struct Connection {
 39    pub user_id: UserId,
 40    pub admin: bool,
 41    pub zed_version: ZedVersion,
 42}
 43
 44impl ConnectionPool {
 45    pub fn reset(&mut self) {
 46        self.connections.clear();
 47        self.connected_users.clear();
 48        self.channels.clear();
 49    }
 50
 51    #[instrument(skip(self))]
 52    pub fn add_connection(
 53        &mut self,
 54        connection_id: ConnectionId,
 55        user_id: UserId,
 56        admin: bool,
 57        zed_version: ZedVersion,
 58    ) {
 59        self.connections.insert(
 60            connection_id,
 61            Connection {
 62                user_id,
 63                admin,
 64                zed_version,
 65            },
 66        );
 67        let connected_user = self.connected_users.entry(user_id).or_default();
 68        connected_user.connection_ids.insert(connection_id);
 69    }
 70
 71    #[instrument(skip(self))]
 72    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 73        let connection = self
 74            .connections
 75            .get_mut(&connection_id)
 76            .ok_or_else(|| anyhow!("no such connection"))?;
 77
 78        let user_id = connection.user_id;
 79        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 80        connected_user.connection_ids.remove(&connection_id);
 81        if connected_user.connection_ids.is_empty() {
 82            self.connected_users.remove(&user_id);
 83            self.channels.remove_user(&user_id);
 84        }
 85        self.connections.remove(&connection_id).unwrap();
 86        Ok(())
 87    }
 88
 89    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 90        self.connections.values()
 91    }
 92
 93    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
 94        self.connected_users
 95            .get(&user_id)
 96            .into_iter()
 97            .flat_map(|state| {
 98                state
 99                    .connection_ids
100                    .iter()
101                    .flat_map(|cid| self.connections.get(cid))
102            })
103    }
104
105    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
106        self.connected_users
107            .get(&user_id)
108            .into_iter()
109            .flat_map(|state| &state.connection_ids)
110            .copied()
111    }
112
113    pub fn channel_user_ids(
114        &self,
115        channel_id: ChannelId,
116    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
117        self.channels.users_to_notify(channel_id)
118    }
119
120    pub fn channel_connection_ids(
121        &self,
122        channel_id: ChannelId,
123    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
124        self.channels
125            .users_to_notify(channel_id)
126            .flat_map(|(user_id, role)| {
127                self.user_connection_ids(user_id)
128                    .map(move |connection_id| (connection_id, role))
129            })
130    }
131
132    pub fn subscribe_to_channel(
133        &mut self,
134        user_id: UserId,
135        channel_id: ChannelId,
136        role: ChannelRole,
137    ) {
138        self.channels.subscribe(user_id, channel_id, role);
139    }
140
141    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
142        self.channels.unsubscribe(user_id, channel_id);
143    }
144
145    pub fn is_user_online(&self, user_id: UserId) -> bool {
146        !self
147            .connected_users
148            .get(&user_id)
149            .unwrap_or(&Default::default())
150            .connection_ids
151            .is_empty()
152    }
153
154    #[cfg(test)]
155    pub fn check_invariants(&self) {
156        for (connection_id, connection) in &self.connections {
157            assert!(self
158                .connected_users
159                .get(&connection.user_id)
160                .unwrap()
161                .connection_ids
162                .contains(connection_id));
163        }
164
165        for (user_id, state) in &self.connected_users {
166            for connection_id in &state.connection_ids {
167                assert_eq!(
168                    self.connections.get(connection_id).unwrap().user_id,
169                    *user_id
170                );
171            }
172        }
173    }
174}
175
176#[derive(Default, Serialize)]
177pub struct ChannelPool {
178    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
179    by_channel: HashMap<ChannelId, HashSet<UserId>>,
180}
181
182impl ChannelPool {
183    pub fn clear(&mut self) {
184        self.by_user.clear();
185        self.by_channel.clear();
186    }
187
188    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
189        self.by_user
190            .entry(user_id)
191            .or_default()
192            .insert(channel_id, role);
193        self.by_channel
194            .entry(channel_id)
195            .or_default()
196            .insert(user_id);
197    }
198
199    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
200        if let Some(channels) = self.by_user.get_mut(user_id) {
201            channels.remove(channel_id);
202            if channels.is_empty() {
203                self.by_user.remove(user_id);
204            }
205        }
206        if let Some(users) = self.by_channel.get_mut(channel_id) {
207            users.remove(user_id);
208            if users.is_empty() {
209                self.by_channel.remove(channel_id);
210            }
211        }
212    }
213
214    pub fn remove_user(&mut self, user_id: &UserId) {
215        if let Some(channels) = self.by_user.remove(&user_id) {
216            for channel_id in channels.keys() {
217                self.unsubscribe(user_id, &channel_id)
218            }
219        }
220    }
221
222    pub fn users_to_notify(
223        &self,
224        channel_id: ChannelId,
225    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
226        self.by_channel
227            .get(&channel_id)
228            .into_iter()
229            .flat_map(move |users| {
230                users.iter().flat_map(move |user_id| {
231                    Some((
232                        *user_id,
233                        self.by_user
234                            .get(user_id)
235                            .and_then(|channels| channels.get(&channel_id))
236                            .copied()?,
237                    ))
238                })
239            })
240    }
241}