connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
 15    channels: ChannelPool,
 16    offline_dev_servers: HashSet<DevServerId>,
 17}
 18
 19#[derive(Default, Serialize)]
 20struct ConnectedPrincipal {
 21    connection_ids: HashSet<ConnectionId>,
 22}
 23
 24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 25pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 134, 0)
 36    }
 37
 38    pub fn with_list_directory() -> ZedVersion {
 39        ZedVersion(SemanticVersion::new(0, 145, 0))
 40    }
 41
 42    pub fn with_search_candidates() -> ZedVersion {
 43        ZedVersion(SemanticVersion::new(0, 151, 0))
 44    }
 45}
 46
 47#[derive(Serialize)]
 48pub struct Connection {
 49    pub principal_id: PrincipalId,
 50    pub admin: bool,
 51    pub zed_version: ZedVersion,
 52}
 53
 54impl ConnectionPool {
 55    pub fn reset(&mut self) {
 56        self.connections.clear();
 57        self.connected_users.clear();
 58        self.connected_dev_servers.clear();
 59        self.channels.clear();
 60    }
 61
 62    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 63        self.connections.get(&connection_id)
 64    }
 65
 66    #[instrument(skip(self))]
 67    pub fn add_connection(
 68        &mut self,
 69        connection_id: ConnectionId,
 70        user_id: UserId,
 71        admin: bool,
 72        zed_version: ZedVersion,
 73    ) {
 74        self.connections.insert(
 75            connection_id,
 76            Connection {
 77                principal_id: PrincipalId::UserId(user_id),
 78                admin,
 79                zed_version,
 80            },
 81        );
 82        let connected_user = self.connected_users.entry(user_id).or_default();
 83        connected_user.connection_ids.insert(connection_id);
 84    }
 85
 86    pub fn add_dev_server(
 87        &mut self,
 88        connection_id: ConnectionId,
 89        dev_server_id: DevServerId,
 90        zed_version: ZedVersion,
 91    ) {
 92        self.connections.insert(
 93            connection_id,
 94            Connection {
 95                principal_id: PrincipalId::DevServerId(dev_server_id),
 96                admin: false,
 97                zed_version,
 98            },
 99        );
100
101        self.connected_dev_servers
102            .insert(dev_server_id, connection_id);
103    }
104
105    #[instrument(skip(self))]
106    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
107        let connection = self
108            .connections
109            .get_mut(&connection_id)
110            .ok_or_else(|| anyhow!("no such connection"))?;
111
112        match connection.principal_id {
113            PrincipalId::UserId(user_id) => {
114                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
115                connected_user.connection_ids.remove(&connection_id);
116                if connected_user.connection_ids.is_empty() {
117                    self.connected_users.remove(&user_id);
118                    self.channels.remove_user(&user_id);
119                }
120            }
121            PrincipalId::DevServerId(dev_server_id) => {
122                self.connected_dev_servers.remove(&dev_server_id);
123                self.offline_dev_servers.remove(&dev_server_id);
124            }
125        }
126        self.connections.remove(&connection_id).unwrap();
127        Ok(())
128    }
129
130    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
131        self.offline_dev_servers.insert(dev_server_id);
132    }
133
134    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
135        self.connections.values()
136    }
137
138    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
139        self.connected_users
140            .get(&user_id)
141            .into_iter()
142            .flat_map(|state| {
143                state
144                    .connection_ids
145                    .iter()
146                    .flat_map(|cid| self.connections.get(cid))
147            })
148    }
149
150    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
151        self.connected_users
152            .get(&user_id)
153            .into_iter()
154            .flat_map(|state| &state.connection_ids)
155            .copied()
156    }
157
158    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
159        if self.dev_server_connection_id(dev_server_id).is_some()
160            && !self.offline_dev_servers.contains(&dev_server_id)
161        {
162            proto::DevServerStatus::Online
163        } else {
164            proto::DevServerStatus::Offline
165        }
166    }
167
168    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
169        self.connected_dev_servers.get(&dev_server_id).copied()
170    }
171
172    pub fn dev_server_connection_id_supporting(
173        &self,
174        dev_server_id: DevServerId,
175        required: ZedVersion,
176    ) -> Result<ConnectionId> {
177        match self.connected_dev_servers.get(&dev_server_id) {
178            Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
179            Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
180            None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
181        }
182    }
183
184    pub fn channel_user_ids(
185        &self,
186        channel_id: ChannelId,
187    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
188        self.channels.users_to_notify(channel_id)
189    }
190
191    pub fn channel_connection_ids(
192        &self,
193        channel_id: ChannelId,
194    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
195        self.channels
196            .users_to_notify(channel_id)
197            .flat_map(|(user_id, role)| {
198                self.user_connection_ids(user_id)
199                    .map(move |connection_id| (connection_id, role))
200            })
201    }
202
203    pub fn subscribe_to_channel(
204        &mut self,
205        user_id: UserId,
206        channel_id: ChannelId,
207        role: ChannelRole,
208    ) {
209        self.channels.subscribe(user_id, channel_id, role);
210    }
211
212    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
213        self.channels.unsubscribe(user_id, channel_id);
214    }
215
216    pub fn is_user_online(&self, user_id: UserId) -> bool {
217        !self
218            .connected_users
219            .get(&user_id)
220            .unwrap_or(&Default::default())
221            .connection_ids
222            .is_empty()
223    }
224
225    #[cfg(test)]
226    pub fn check_invariants(&self) {
227        for (connection_id, connection) in &self.connections {
228            match &connection.principal_id {
229                PrincipalId::UserId(user_id) => {
230                    assert!(self
231                        .connected_users
232                        .get(user_id)
233                        .unwrap()
234                        .connection_ids
235                        .contains(connection_id));
236                }
237                PrincipalId::DevServerId(dev_server_id) => {
238                    assert_eq!(
239                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
240                        connection_id
241                    );
242                }
243            }
244        }
245
246        for (user_id, state) in &self.connected_users {
247            for connection_id in &state.connection_ids {
248                assert_eq!(
249                    self.connections.get(connection_id).unwrap().principal_id,
250                    PrincipalId::UserId(*user_id)
251                );
252            }
253        }
254
255        for (dev_server_id, connection_id) in &self.connected_dev_servers {
256            assert_eq!(
257                self.connections.get(connection_id).unwrap().principal_id,
258                PrincipalId::DevServerId(*dev_server_id)
259            );
260        }
261    }
262}
263
264#[derive(Default, Serialize)]
265pub struct ChannelPool {
266    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
267    by_channel: HashMap<ChannelId, HashSet<UserId>>,
268}
269
270impl ChannelPool {
271    pub fn clear(&mut self) {
272        self.by_user.clear();
273        self.by_channel.clear();
274    }
275
276    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
277        self.by_user
278            .entry(user_id)
279            .or_default()
280            .insert(channel_id, role);
281        self.by_channel
282            .entry(channel_id)
283            .or_default()
284            .insert(user_id);
285    }
286
287    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
288        if let Some(channels) = self.by_user.get_mut(user_id) {
289            channels.remove(channel_id);
290            if channels.is_empty() {
291                self.by_user.remove(user_id);
292            }
293        }
294        if let Some(users) = self.by_channel.get_mut(channel_id) {
295            users.remove(user_id);
296            if users.is_empty() {
297                self.by_channel.remove(channel_id);
298            }
299        }
300    }
301
302    pub fn remove_user(&mut self, user_id: &UserId) {
303        if let Some(channels) = self.by_user.remove(&user_id) {
304            for channel_id in channels.keys() {
305                self.unsubscribe(user_id, &channel_id)
306            }
307        }
308    }
309
310    pub fn users_to_notify(
311        &self,
312        channel_id: ChannelId,
313    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
314        self.by_channel
315            .get(&channel_id)
316            .into_iter()
317            .flat_map(move |users| {
318                users.iter().flat_map(move |user_id| {
319                    Some((
320                        *user_id,
321                        self.by_user
322                            .get(user_id)
323                            .and_then(|channels| channels.get(&channel_id))
324                            .copied()?,
325                    ))
326                })
327            })
328    }
329}