connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, UserId};
  2use anyhow::{Context as _, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::ConnectionId;
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    channels: ChannelPool,
 15}
 16
 17#[derive(Default, Serialize)]
 18struct ConnectedPrincipal {
 19    connection_ids: HashSet<ConnectionId>,
 20}
 21
 22#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 23pub struct ZedVersion(pub SemanticVersion);
 24
 25impl fmt::Display for ZedVersion {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31impl ZedVersion {
 32    pub fn can_collaborate(&self) -> bool {
 33        self.0 >= SemanticVersion::new(0, 157, 0)
 34    }
 35}
 36
 37#[derive(Serialize)]
 38pub struct Connection {
 39    pub user_id: UserId,
 40    pub admin: bool,
 41    pub zed_version: ZedVersion,
 42}
 43
 44impl ConnectionPool {
 45    pub fn reset(&mut self) {
 46        self.connections.clear();
 47        self.connected_users.clear();
 48        self.channels.clear();
 49    }
 50
 51    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 52        self.connections.get(&connection_id)
 53    }
 54
 55    #[instrument(skip(self))]
 56    pub fn add_connection(
 57        &mut self,
 58        connection_id: ConnectionId,
 59        user_id: UserId,
 60        admin: bool,
 61        zed_version: ZedVersion,
 62    ) {
 63        self.connections.insert(
 64            connection_id,
 65            Connection {
 66                user_id,
 67                admin,
 68                zed_version,
 69            },
 70        );
 71        let connected_user = self.connected_users.entry(user_id).or_default();
 72        connected_user.connection_ids.insert(connection_id);
 73    }
 74
 75    #[instrument(skip(self))]
 76    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 77        let connection = self
 78            .connections
 79            .get_mut(&connection_id)
 80            .context("no such connection")?;
 81
 82        let user_id = connection.user_id;
 83
 84        let connected_user = self.connected_users.get_mut(&user_id).unwrap();
 85        connected_user.connection_ids.remove(&connection_id);
 86        if connected_user.connection_ids.is_empty() {
 87            self.connected_users.remove(&user_id);
 88            self.channels.remove_user(&user_id);
 89        };
 90        self.connections.remove(&connection_id).unwrap();
 91        Ok(())
 92    }
 93
 94    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
 95        self.connections.values()
 96    }
 97
 98    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
 99        self.connected_users
100            .get(&user_id)
101            .into_iter()
102            .flat_map(|state| {
103                state
104                    .connection_ids
105                    .iter()
106                    .flat_map(|cid| self.connections.get(cid))
107            })
108    }
109
110    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
111        self.connected_users
112            .get(&user_id)
113            .into_iter()
114            .flat_map(|state| &state.connection_ids)
115            .copied()
116    }
117
118    pub fn channel_user_ids(
119        &self,
120        channel_id: ChannelId,
121    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
122        self.channels.users_to_notify(channel_id)
123    }
124
125    pub fn channel_connection_ids(
126        &self,
127        channel_id: ChannelId,
128    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
129        self.channels
130            .users_to_notify(channel_id)
131            .flat_map(|(user_id, role)| {
132                self.user_connection_ids(user_id)
133                    .map(move |connection_id| (connection_id, role))
134            })
135    }
136
137    pub fn subscribe_to_channel(
138        &mut self,
139        user_id: UserId,
140        channel_id: ChannelId,
141        role: ChannelRole,
142    ) {
143        self.channels.subscribe(user_id, channel_id, role);
144    }
145
146    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
147        self.channels.unsubscribe(user_id, channel_id);
148    }
149
150    pub fn is_user_online(&self, user_id: UserId) -> bool {
151        !self
152            .connected_users
153            .get(&user_id)
154            .unwrap_or(&Default::default())
155            .connection_ids
156            .is_empty()
157    }
158
159    #[cfg(test)]
160    pub fn check_invariants(&self) {
161        for (connection_id, connection) in &self.connections {
162            assert!(
163                self.connected_users
164                    .get(&connection.user_id)
165                    .unwrap()
166                    .connection_ids
167                    .contains(connection_id)
168            );
169        }
170
171        for (user_id, state) in &self.connected_users {
172            for connection_id in &state.connection_ids {
173                assert_eq!(
174                    self.connections.get(connection_id).unwrap().user_id,
175                    *user_id
176                );
177            }
178        }
179    }
180}
181
182#[derive(Default, Serialize)]
183pub struct ChannelPool {
184    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
185    by_channel: HashMap<ChannelId, HashSet<UserId>>,
186}
187
188impl ChannelPool {
189    pub fn clear(&mut self) {
190        self.by_user.clear();
191        self.by_channel.clear();
192    }
193
194    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
195        self.by_user
196            .entry(user_id)
197            .or_default()
198            .insert(channel_id, role);
199        self.by_channel
200            .entry(channel_id)
201            .or_default()
202            .insert(user_id);
203    }
204
205    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
206        if let Some(channels) = self.by_user.get_mut(user_id) {
207            channels.remove(channel_id);
208            if channels.is_empty() {
209                self.by_user.remove(user_id);
210            }
211        }
212        if let Some(users) = self.by_channel.get_mut(channel_id) {
213            users.remove(user_id);
214            if users.is_empty() {
215                self.by_channel.remove(channel_id);
216            }
217        }
218    }
219
220    pub fn remove_user(&mut self, user_id: &UserId) {
221        if let Some(channels) = self.by_user.remove(user_id) {
222            for channel_id in channels.keys() {
223                self.unsubscribe(user_id, channel_id)
224            }
225        }
226    }
227
228    pub fn users_to_notify(
229        &self,
230        channel_id: ChannelId,
231    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
232        self.by_channel
233            .get(&channel_id)
234            .into_iter()
235            .flat_map(move |users| {
236                users.iter().flat_map(move |user_id| {
237                    Some((
238                        *user_id,
239                        self.by_user
240                            .get(user_id)
241                            .and_then(|channels| channels.get(&channel_id))
242                            .copied()?,
243                    ))
244                })
245            })
246    }
247}