connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::ConnectionId;
  5use serde::Serialize;
  6use tracing::instrument;
  7use util::SemanticVersion;
  8
  9#[derive(Default, Serialize)]
 10pub struct ConnectionPool {
 11    connections: BTreeMap<ConnectionId, Connection>,
 12    connected_users: BTreeMap<UserId, ConnectedUser>,
 13    channels: ChannelPool,
 14}
 15
 16#[derive(Default, Serialize)]
 17struct ConnectedUser {
 18    connection_ids: HashSet<ConnectionId>,
 19}
 20
 21#[derive(Debug, Serialize)]
 22pub struct ZedVersion(pub SemanticVersion);
 23use std::fmt;
 24
 25impl fmt::Display for ZedVersion {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31impl ZedVersion {
 32    pub fn is_supported(&self) -> bool {
 33        self.0 != SemanticVersion::new(0, 123, 0)
 34    }
 35    pub fn supports_talker_role(&self) -> bool {
 36        self.0 >= SemanticVersion::new(0, 125, 0)
 37    }
 38}
 39
 40#[derive(Serialize)]
 41pub struct Connection {
 42    pub user_id: UserId,
 43    pub admin: bool,
 44    pub zed_version: ZedVersion,
 45}
 46
 47impl ConnectionPool {
 48    pub fn reset(&mut self) {
 49        self.connections.clear();
 50        self.connected_users.clear();
 51        self.channels.clear();
 52    }
 53
 54    #[instrument(skip(self))]
 55    pub fn add_connection(
 56        &mut self,
 57        connection_id: ConnectionId,
 58        user_id: UserId,
 59        admin: bool,
 60        zed_version: ZedVersion,
 61    ) {
 62        self.connections.insert(
 63            connection_id,
 64            Connection {
 65                user_id,
 66                admin,
 67                zed_version,
 68            },
 69        );
 70        let connected_user = self.connected_users.entry(user_id).or_default();
 71        connected_user.connection_ids.insert(connection_id);
 72    }
 73
 74    #[instrument(skip(self))]
 75    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 76        let connection = self
 77            .connections
 78            .get_mut(&connection_id)
 79            .ok_or_else(|| anyhow!("no such connection"))?;
 80
 81        let user_id = connection.user_id;
 82        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 83        connected_user.connection_ids.remove(&connection_id);
 84        if connected_user.connection_ids.is_empty() {
 85            self.connected_users.remove(&user_id);
 86            self.channels.remove_user(&user_id);
 87        }
 88        self.connections.remove(&connection_id).unwrap();
 89        Ok(())
 90    }
 91
 92    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 93        self.connections.values()
 94    }
 95
 96    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
 97        self.connected_users
 98            .get(&user_id)
 99            .into_iter()
100            .flat_map(|state| {
101                state
102                    .connection_ids
103                    .iter()
104                    .flat_map(|cid| self.connections.get(cid))
105            })
106    }
107
108    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
109        self.connected_users
110            .get(&user_id)
111            .into_iter()
112            .flat_map(|state| &state.connection_ids)
113            .copied()
114    }
115
116    pub fn channel_user_ids(
117        &self,
118        channel_id: ChannelId,
119    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
120        self.channels.users_to_notify(channel_id)
121    }
122
123    pub fn channel_connection_ids(
124        &self,
125        channel_id: ChannelId,
126    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
127        self.channels
128            .users_to_notify(channel_id)
129            .flat_map(|(user_id, role)| {
130                self.user_connection_ids(user_id)
131                    .map(move |connection_id| (connection_id, role))
132            })
133    }
134
135    pub fn subscribe_to_channel(
136        &mut self,
137        user_id: UserId,
138        channel_id: ChannelId,
139        role: ChannelRole,
140    ) {
141        self.channels.subscribe(user_id, channel_id, role);
142    }
143
144    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
145        self.channels.unsubscribe(user_id, channel_id);
146    }
147
148    pub fn is_user_online(&self, user_id: UserId) -> bool {
149        !self
150            .connected_users
151            .get(&user_id)
152            .unwrap_or(&Default::default())
153            .connection_ids
154            .is_empty()
155    }
156
157    #[cfg(test)]
158    pub fn check_invariants(&self) {
159        for (connection_id, connection) in &self.connections {
160            assert!(self
161                .connected_users
162                .get(&connection.user_id)
163                .unwrap()
164                .connection_ids
165                .contains(connection_id));
166        }
167
168        for (user_id, state) in &self.connected_users {
169            for connection_id in &state.connection_ids {
170                assert_eq!(
171                    self.connections.get(connection_id).unwrap().user_id,
172                    *user_id
173                );
174            }
175        }
176    }
177}
178
179#[derive(Default, Serialize)]
180pub struct ChannelPool {
181    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
182    by_channel: HashMap<ChannelId, HashSet<UserId>>,
183}
184
185impl ChannelPool {
186    pub fn clear(&mut self) {
187        self.by_user.clear();
188        self.by_channel.clear();
189    }
190
191    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
192        self.by_user
193            .entry(user_id)
194            .or_default()
195            .insert(channel_id, role);
196        self.by_channel
197            .entry(channel_id)
198            .or_default()
199            .insert(user_id);
200    }
201
202    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
203        if let Some(channels) = self.by_user.get_mut(user_id) {
204            channels.remove(channel_id);
205            if channels.is_empty() {
206                self.by_user.remove(user_id);
207            }
208        }
209        if let Some(users) = self.by_channel.get_mut(channel_id) {
210            users.remove(user_id);
211            if users.is_empty() {
212                self.by_channel.remove(channel_id);
213            }
214        }
215    }
216
217    pub fn remove_user(&mut self, user_id: &UserId) {
218        if let Some(channels) = self.by_user.remove(&user_id) {
219            for channel_id in channels.keys() {
220                self.unsubscribe(user_id, &channel_id)
221            }
222        }
223    }
224
225    pub fn users_to_notify(
226        &self,
227        channel_id: ChannelId,
228    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
229        self.by_channel
230            .get(&channel_id)
231            .into_iter()
232            .flat_map(move |users| {
233                users.iter().flat_map(move |user_id| {
234                    Some((
235                        *user_id,
236                        self.by_user
237                            .get(user_id)
238                            .and_then(|channels| channels.get(&channel_id))
239                            .copied()?,
240                    ))
241                })
242            })
243    }
244}