connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
/// Registry of live connections together with the indexes needed to look them up
/// by user, by dev server, and by channel subscription.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every registered connection, keyed by its connection id.
    connections: BTreeMap<ConnectionId, Connection>,
    /// For each user with at least one registered connection, the ids of those connections.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    /// The connection id recorded for each registered dev server.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    /// Channel subscriptions of users.
    channels: ChannelPool,
    /// Dev servers flagged via `set_dev_server_offline`; `dev_server_status` reports
    /// these as offline even while a connection is still recorded for them.
    offline_dev_servers: HashSet<DevServerId>,
}
 18
/// The set of connection ids currently registered for a single user.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    /// Ids of the connections that belong to this principal.
    connection_ids: HashSet<ConnectionId>,
}
 23
/// Newtype around the [`SemanticVersion`] recorded for a connection.
#[derive(Debug, Serialize)]
pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 129, 2)
 36    }
 37}
 38
/// Metadata stored in the pool for one registered connection.
#[derive(Serialize)]
pub struct Connection {
    /// The user or dev server this connection belongs to.
    pub principal_id: PrincipalId,
    /// Admin flag supplied when the connection was added; always `false` for dev servers.
    pub admin: bool,
    /// Version supplied when the connection was added.
    pub zed_version: ZedVersion,
}
 45
 46impl ConnectionPool {
 47    pub fn reset(&mut self) {
 48        self.connections.clear();
 49        self.connected_users.clear();
 50        self.channels.clear();
 51    }
 52
 53    #[instrument(skip(self))]
 54    pub fn add_connection(
 55        &mut self,
 56        connection_id: ConnectionId,
 57        user_id: UserId,
 58        admin: bool,
 59        zed_version: ZedVersion,
 60    ) {
 61        self.connections.insert(
 62            connection_id,
 63            Connection {
 64                principal_id: PrincipalId::UserId(user_id),
 65                admin,
 66                zed_version,
 67            },
 68        );
 69        let connected_user = self.connected_users.entry(user_id).or_default();
 70        connected_user.connection_ids.insert(connection_id);
 71    }
 72
 73    pub fn add_dev_server(
 74        &mut self,
 75        connection_id: ConnectionId,
 76        dev_server_id: DevServerId,
 77        zed_version: ZedVersion,
 78    ) {
 79        self.connections.insert(
 80            connection_id,
 81            Connection {
 82                principal_id: PrincipalId::DevServerId(dev_server_id),
 83                admin: false,
 84                zed_version,
 85            },
 86        );
 87
 88        self.connected_dev_servers
 89            .insert(dev_server_id, connection_id);
 90    }
 91
 92    #[instrument(skip(self))]
 93    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 94        let connection = self
 95            .connections
 96            .get_mut(&connection_id)
 97            .ok_or_else(|| anyhow!("no such connection"))?;
 98
 99        match connection.principal_id {
100            PrincipalId::UserId(user_id) => {
101                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
102                connected_user.connection_ids.remove(&connection_id);
103                if connected_user.connection_ids.is_empty() {
104                    self.connected_users.remove(&user_id);
105                    self.channels.remove_user(&user_id);
106                }
107            }
108            PrincipalId::DevServerId(dev_server_id) => {
109                self.connected_dev_servers.remove(&dev_server_id);
110                self.offline_dev_servers.remove(&dev_server_id);
111            }
112        }
113        self.connections.remove(&connection_id).unwrap();
114        Ok(())
115    }
116
117    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
118        self.offline_dev_servers.insert(dev_server_id);
119    }
120
121    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
122        self.connections.values()
123    }
124
125    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
126        self.connected_users
127            .get(&user_id)
128            .into_iter()
129            .flat_map(|state| {
130                state
131                    .connection_ids
132                    .iter()
133                    .flat_map(|cid| self.connections.get(cid))
134            })
135    }
136
137    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
138        self.connected_users
139            .get(&user_id)
140            .into_iter()
141            .flat_map(|state| &state.connection_ids)
142            .copied()
143    }
144
145    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
146        if self.dev_server_connection_id(dev_server_id).is_some()
147            && !self.offline_dev_servers.contains(&dev_server_id)
148        {
149            proto::DevServerStatus::Online
150        } else {
151            proto::DevServerStatus::Offline
152        }
153    }
154
155    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
156        self.connected_dev_servers.get(&dev_server_id).copied()
157    }
158
159    pub fn channel_user_ids(
160        &self,
161        channel_id: ChannelId,
162    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
163        self.channels.users_to_notify(channel_id)
164    }
165
166    pub fn channel_connection_ids(
167        &self,
168        channel_id: ChannelId,
169    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
170        self.channels
171            .users_to_notify(channel_id)
172            .flat_map(|(user_id, role)| {
173                self.user_connection_ids(user_id)
174                    .map(move |connection_id| (connection_id, role))
175            })
176    }
177
178    pub fn subscribe_to_channel(
179        &mut self,
180        user_id: UserId,
181        channel_id: ChannelId,
182        role: ChannelRole,
183    ) {
184        self.channels.subscribe(user_id, channel_id, role);
185    }
186
187    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
188        self.channels.unsubscribe(user_id, channel_id);
189    }
190
191    pub fn is_user_online(&self, user_id: UserId) -> bool {
192        !self
193            .connected_users
194            .get(&user_id)
195            .unwrap_or(&Default::default())
196            .connection_ids
197            .is_empty()
198    }
199
200    #[cfg(test)]
201    pub fn check_invariants(&self) {
202        for (connection_id, connection) in &self.connections {
203            match &connection.principal_id {
204                PrincipalId::UserId(user_id) => {
205                    assert!(self
206                        .connected_users
207                        .get(user_id)
208                        .unwrap()
209                        .connection_ids
210                        .contains(connection_id));
211                }
212                PrincipalId::DevServerId(dev_server_id) => {
213                    assert_eq!(
214                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
215                        connection_id
216                    );
217                }
218            }
219        }
220
221        for (user_id, state) in &self.connected_users {
222            for connection_id in &state.connection_ids {
223                assert_eq!(
224                    self.connections.get(connection_id).unwrap().principal_id,
225                    PrincipalId::UserId(*user_id)
226                );
227            }
228        }
229
230        for (dev_server_id, connection_id) in &self.connected_dev_servers {
231            assert_eq!(
232                self.connections.get(connection_id).unwrap().principal_id,
233                PrincipalId::DevServerId(*dev_server_id)
234            );
235        }
236    }
237}
238
/// Two-way index of channel subscriptions: channels (with roles) per user, and
/// users per channel.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// For each subscribed user, the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// For each channel with subscribers, the users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
244
245impl ChannelPool {
246    pub fn clear(&mut self) {
247        self.by_user.clear();
248        self.by_channel.clear();
249    }
250
251    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
252        self.by_user
253            .entry(user_id)
254            .or_default()
255            .insert(channel_id, role);
256        self.by_channel
257            .entry(channel_id)
258            .or_default()
259            .insert(user_id);
260    }
261
262    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
263        if let Some(channels) = self.by_user.get_mut(user_id) {
264            channels.remove(channel_id);
265            if channels.is_empty() {
266                self.by_user.remove(user_id);
267            }
268        }
269        if let Some(users) = self.by_channel.get_mut(channel_id) {
270            users.remove(user_id);
271            if users.is_empty() {
272                self.by_channel.remove(channel_id);
273            }
274        }
275    }
276
277    pub fn remove_user(&mut self, user_id: &UserId) {
278        if let Some(channels) = self.by_user.remove(&user_id) {
279            for channel_id in channels.keys() {
280                self.unsubscribe(user_id, &channel_id)
281            }
282        }
283    }
284
285    pub fn users_to_notify(
286        &self,
287        channel_id: ChannelId,
288    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
289        self.by_channel
290            .get(&channel_id)
291            .into_iter()
292            .flat_map(move |users| {
293                users.iter().flat_map(move |user_id| {
294                    Some((
295                        *user_id,
296                        self.by_user
297                            .get(user_id)
298                            .and_then(|channels| channels.get(&channel_id))
299                            .copied()?,
300                    ))
301                })
302            })
303    }
304}