connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
/// In-memory registry of the currently registered connections.
///
/// Every connection is stored in `connections`. User connections are additionally
/// indexed by user in `connected_users`, and dev-server connections by dev server in
/// `connected_dev_servers`. Channel subscriptions are kept in `channels`.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    // Every registered connection (user or dev server), keyed by connection id.
    connections: BTreeMap<ConnectionId, Connection>,
    // The set of connection ids each user currently has open.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    // The connection most recently registered for each dev server.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    // User <-> channel subscriptions, including each user's role in the channel.
    channels: ChannelPool,
}
 17
/// The connections currently open for one principal.
///
/// In this file it is only used as the value type of `ConnectionPool::connected_users`,
/// i.e. it holds the connection ids of a single user.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    // Ids of the principal's open connections; each is also a key of
    // `ConnectionPool::connections`.
    connection_ids: HashSet<ConnectionId>,
}
 22
/// The Zed version associated with a connection (presumably the version the client
/// reported on connect — confirm against the caller of `add_connection`).
#[derive(Debug, Serialize)]
pub struct ZedVersion(pub SemanticVersion);
 25
 26impl fmt::Display for ZedVersion {
 27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 28        write!(f, "{}", self.0)
 29    }
 30}
 31
 32impl ZedVersion {
 33    pub fn can_collaborate(&self) -> bool {
 34        self.0 >= SemanticVersion::new(0, 129, 2)
 35    }
 36}
 37
/// Metadata stored for each connection registered in a `ConnectionPool`.
#[derive(Serialize)]
pub struct Connection {
    // Owner of the connection: a user (`add_connection`) or a dev server (`add_dev_server`).
    pub principal_id: PrincipalId,
    // Admin flag passed to `add_connection`; `add_dev_server` always stores `false`.
    pub admin: bool,
    // Zed version supplied when the connection was registered.
    pub zed_version: ZedVersion,
}
 44
 45impl ConnectionPool {
 46    pub fn reset(&mut self) {
 47        self.connections.clear();
 48        self.connected_users.clear();
 49        self.channels.clear();
 50    }
 51
 52    #[instrument(skip(self))]
 53    pub fn add_connection(
 54        &mut self,
 55        connection_id: ConnectionId,
 56        user_id: UserId,
 57        admin: bool,
 58        zed_version: ZedVersion,
 59    ) {
 60        self.connections.insert(
 61            connection_id,
 62            Connection {
 63                principal_id: PrincipalId::UserId(user_id),
 64                admin,
 65                zed_version,
 66            },
 67        );
 68        let connected_user = self.connected_users.entry(user_id).or_default();
 69        connected_user.connection_ids.insert(connection_id);
 70    }
 71
 72    pub fn add_dev_server(
 73        &mut self,
 74        connection_id: ConnectionId,
 75        dev_server_id: DevServerId,
 76        zed_version: ZedVersion,
 77    ) {
 78        self.connections.insert(
 79            connection_id,
 80            Connection {
 81                principal_id: PrincipalId::DevServerId(dev_server_id),
 82                admin: false,
 83                zed_version,
 84            },
 85        );
 86
 87        self.connected_dev_servers
 88            .insert(dev_server_id, connection_id);
 89    }
 90
 91    #[instrument(skip(self))]
 92    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
 93        let connection = self
 94            .connections
 95            .get_mut(&connection_id)
 96            .ok_or_else(|| anyhow!("no such connection"))?;
 97
 98        match connection.principal_id {
 99            PrincipalId::UserId(user_id) => {
100                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
101                connected_user.connection_ids.remove(&connection_id);
102                if connected_user.connection_ids.is_empty() {
103                    self.connected_users.remove(&user_id);
104                    self.channels.remove_user(&user_id);
105                }
106            }
107            PrincipalId::DevServerId(dev_server_id) => {
108                self.connected_dev_servers.remove(&dev_server_id);
109            }
110        }
111        self.connections.remove(&connection_id).unwrap();
112        Ok(())
113    }
114
115    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
116        self.connections.values()
117    }
118
119    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
120        self.connected_users
121            .get(&user_id)
122            .into_iter()
123            .flat_map(|state| {
124                state
125                    .connection_ids
126                    .iter()
127                    .flat_map(|cid| self.connections.get(cid))
128            })
129    }
130
131    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
132        self.connected_users
133            .get(&user_id)
134            .into_iter()
135            .flat_map(|state| &state.connection_ids)
136            .copied()
137    }
138
139    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
140        if self.dev_server_connection_id(dev_server_id).is_some() {
141            proto::DevServerStatus::Online
142        } else {
143            proto::DevServerStatus::Offline
144        }
145    }
146
147    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
148        self.connected_dev_servers.get(&dev_server_id).copied()
149    }
150
151    pub fn channel_user_ids(
152        &self,
153        channel_id: ChannelId,
154    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
155        self.channels.users_to_notify(channel_id)
156    }
157
158    pub fn channel_connection_ids(
159        &self,
160        channel_id: ChannelId,
161    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
162        self.channels
163            .users_to_notify(channel_id)
164            .flat_map(|(user_id, role)| {
165                self.user_connection_ids(user_id)
166                    .map(move |connection_id| (connection_id, role))
167            })
168    }
169
170    pub fn subscribe_to_channel(
171        &mut self,
172        user_id: UserId,
173        channel_id: ChannelId,
174        role: ChannelRole,
175    ) {
176        self.channels.subscribe(user_id, channel_id, role);
177    }
178
179    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
180        self.channels.unsubscribe(user_id, channel_id);
181    }
182
183    pub fn is_user_online(&self, user_id: UserId) -> bool {
184        !self
185            .connected_users
186            .get(&user_id)
187            .unwrap_or(&Default::default())
188            .connection_ids
189            .is_empty()
190    }
191
192    #[cfg(test)]
193    pub fn check_invariants(&self) {
194        for (connection_id, connection) in &self.connections {
195            match &connection.principal_id {
196                PrincipalId::UserId(user_id) => {
197                    assert!(self
198                        .connected_users
199                        .get(user_id)
200                        .unwrap()
201                        .connection_ids
202                        .contains(connection_id));
203                }
204                PrincipalId::DevServerId(dev_server_id) => {
205                    assert_eq!(
206                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
207                        connection_id
208                    );
209                }
210            }
211        }
212
213        for (user_id, state) in &self.connected_users {
214            for connection_id in &state.connection_ids {
215                assert_eq!(
216                    self.connections.get(connection_id).unwrap().principal_id,
217                    PrincipalId::UserId(*user_id)
218                );
219            }
220        }
221
222        for (dev_server_id, connection_id) in &self.connected_dev_servers {
223            assert_eq!(
224                self.connections.get(connection_id).unwrap().principal_id,
225                PrincipalId::DevServerId(*dev_server_id)
226            );
227        }
228    }
229}
230
/// Two-way index of channel subscriptions.
///
/// `subscribe`, `unsubscribe`, and `remove_user` update both maps so that a user
/// appears in `by_channel[c]` exactly when `c` is a key of `by_user[user]`.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    // For each user: the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    // For each channel: the users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
236
237impl ChannelPool {
238    pub fn clear(&mut self) {
239        self.by_user.clear();
240        self.by_channel.clear();
241    }
242
243    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
244        self.by_user
245            .entry(user_id)
246            .or_default()
247            .insert(channel_id, role);
248        self.by_channel
249            .entry(channel_id)
250            .or_default()
251            .insert(user_id);
252    }
253
254    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
255        if let Some(channels) = self.by_user.get_mut(user_id) {
256            channels.remove(channel_id);
257            if channels.is_empty() {
258                self.by_user.remove(user_id);
259            }
260        }
261        if let Some(users) = self.by_channel.get_mut(channel_id) {
262            users.remove(user_id);
263            if users.is_empty() {
264                self.by_channel.remove(channel_id);
265            }
266        }
267    }
268
269    pub fn remove_user(&mut self, user_id: &UserId) {
270        if let Some(channels) = self.by_user.remove(&user_id) {
271            for channel_id in channels.keys() {
272                self.unsubscribe(user_id, &channel_id)
273            }
274        }
275    }
276
277    pub fn users_to_notify(
278        &self,
279        channel_id: ChannelId,
280    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
281        self.by_channel
282            .get(&channel_id)
283            .into_iter()
284            .flat_map(move |users| {
285                users.iter().flat_map(move |user_id| {
286                    Some((
287                        *user_id,
288                        self.by_user
289                            .get(user_id)
290                            .and_then(|channels| channels.get(&channel_id))
291                            .copied()?,
292                    ))
293                })
294            })
295    }
296}