connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
 15    channels: ChannelPool,
 16    offline_dev_servers: HashSet<DevServerId>,
 17}
 18
 19#[derive(Default, Serialize)]
 20struct ConnectedPrincipal {
 21    connection_ids: HashSet<ConnectionId>,
 22}
 23
 24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 25pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 129, 2)
 36    }
 37
 38    pub fn with_save_as() -> ZedVersion {
 39        ZedVersion(SemanticVersion::new(0, 134, 0))
 40    }
 41
 42    pub fn with_list_directory() -> ZedVersion {
 43        ZedVersion(SemanticVersion::new(0, 145, 0))
 44    }
 45}
 46
 47pub trait VersionedMessage {
 48    fn required_host_version(&self) -> Option<ZedVersion> {
 49        None
 50    }
 51}
 52
 53impl VersionedMessage for proto::SaveBuffer {
 54    fn required_host_version(&self) -> Option<ZedVersion> {
 55        if self.new_path.is_some() {
 56            Some(ZedVersion::with_save_as())
 57        } else {
 58            None
 59        }
 60    }
 61}
 62
 63impl VersionedMessage for proto::OpenNewBuffer {
 64    fn required_host_version(&self) -> Option<ZedVersion> {
 65        Some(ZedVersion::with_save_as())
 66    }
 67}
 68
 69#[derive(Serialize)]
 70pub struct Connection {
 71    pub principal_id: PrincipalId,
 72    pub admin: bool,
 73    pub zed_version: ZedVersion,
 74}
 75
 76impl ConnectionPool {
 77    pub fn reset(&mut self) {
 78        self.connections.clear();
 79        self.connected_users.clear();
 80        self.connected_dev_servers.clear();
 81        self.channels.clear();
 82    }
 83
 84    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 85        self.connections.get(&connection_id)
 86    }
 87
 88    #[instrument(skip(self))]
 89    pub fn add_connection(
 90        &mut self,
 91        connection_id: ConnectionId,
 92        user_id: UserId,
 93        admin: bool,
 94        zed_version: ZedVersion,
 95    ) {
 96        self.connections.insert(
 97            connection_id,
 98            Connection {
 99                principal_id: PrincipalId::UserId(user_id),
100                admin,
101                zed_version,
102            },
103        );
104        let connected_user = self.connected_users.entry(user_id).or_default();
105        connected_user.connection_ids.insert(connection_id);
106    }
107
108    pub fn add_dev_server(
109        &mut self,
110        connection_id: ConnectionId,
111        dev_server_id: DevServerId,
112        zed_version: ZedVersion,
113    ) {
114        self.connections.insert(
115            connection_id,
116            Connection {
117                principal_id: PrincipalId::DevServerId(dev_server_id),
118                admin: false,
119                zed_version,
120            },
121        );
122
123        self.connected_dev_servers
124            .insert(dev_server_id, connection_id);
125    }
126
127    #[instrument(skip(self))]
128    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
129        let connection = self
130            .connections
131            .get_mut(&connection_id)
132            .ok_or_else(|| anyhow!("no such connection"))?;
133
134        match connection.principal_id {
135            PrincipalId::UserId(user_id) => {
136                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
137                connected_user.connection_ids.remove(&connection_id);
138                if connected_user.connection_ids.is_empty() {
139                    self.connected_users.remove(&user_id);
140                    self.channels.remove_user(&user_id);
141                }
142            }
143            PrincipalId::DevServerId(dev_server_id) => {
144                self.connected_dev_servers.remove(&dev_server_id);
145                self.offline_dev_servers.remove(&dev_server_id);
146            }
147        }
148        self.connections.remove(&connection_id).unwrap();
149        Ok(())
150    }
151
152    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
153        self.offline_dev_servers.insert(dev_server_id);
154    }
155
156    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
157        self.connections.values()
158    }
159
160    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
161        self.connected_users
162            .get(&user_id)
163            .into_iter()
164            .flat_map(|state| {
165                state
166                    .connection_ids
167                    .iter()
168                    .flat_map(|cid| self.connections.get(cid))
169            })
170    }
171
172    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
173        self.connected_users
174            .get(&user_id)
175            .into_iter()
176            .flat_map(|state| &state.connection_ids)
177            .copied()
178    }
179
180    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
181        if self.dev_server_connection_id(dev_server_id).is_some()
182            && !self.offline_dev_servers.contains(&dev_server_id)
183        {
184            proto::DevServerStatus::Online
185        } else {
186            proto::DevServerStatus::Offline
187        }
188    }
189
190    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
191        self.connected_dev_servers.get(&dev_server_id).copied()
192    }
193
194    pub fn dev_server_connection_id_supporting(
195        &self,
196        dev_server_id: DevServerId,
197        required: ZedVersion,
198    ) -> Result<ConnectionId> {
199        match self.connected_dev_servers.get(&dev_server_id) {
200            Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
201            Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
202            None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
203        }
204    }
205
206    pub fn channel_user_ids(
207        &self,
208        channel_id: ChannelId,
209    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
210        self.channels.users_to_notify(channel_id)
211    }
212
213    pub fn channel_connection_ids(
214        &self,
215        channel_id: ChannelId,
216    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
217        self.channels
218            .users_to_notify(channel_id)
219            .flat_map(|(user_id, role)| {
220                self.user_connection_ids(user_id)
221                    .map(move |connection_id| (connection_id, role))
222            })
223    }
224
225    pub fn subscribe_to_channel(
226        &mut self,
227        user_id: UserId,
228        channel_id: ChannelId,
229        role: ChannelRole,
230    ) {
231        self.channels.subscribe(user_id, channel_id, role);
232    }
233
234    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
235        self.channels.unsubscribe(user_id, channel_id);
236    }
237
238    pub fn is_user_online(&self, user_id: UserId) -> bool {
239        !self
240            .connected_users
241            .get(&user_id)
242            .unwrap_or(&Default::default())
243            .connection_ids
244            .is_empty()
245    }
246
247    #[cfg(test)]
248    pub fn check_invariants(&self) {
249        for (connection_id, connection) in &self.connections {
250            match &connection.principal_id {
251                PrincipalId::UserId(user_id) => {
252                    assert!(self
253                        .connected_users
254                        .get(user_id)
255                        .unwrap()
256                        .connection_ids
257                        .contains(connection_id));
258                }
259                PrincipalId::DevServerId(dev_server_id) => {
260                    assert_eq!(
261                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
262                        connection_id
263                    );
264                }
265            }
266        }
267
268        for (user_id, state) in &self.connected_users {
269            for connection_id in &state.connection_ids {
270                assert_eq!(
271                    self.connections.get(connection_id).unwrap().principal_id,
272                    PrincipalId::UserId(*user_id)
273                );
274            }
275        }
276
277        for (dev_server_id, connection_id) in &self.connected_dev_servers {
278            assert_eq!(
279                self.connections.get(connection_id).unwrap().principal_id,
280                PrincipalId::DevServerId(*dev_server_id)
281            );
282        }
283    }
284}
285
286#[derive(Default, Serialize)]
287pub struct ChannelPool {
288    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
289    by_channel: HashMap<ChannelId, HashSet<UserId>>,
290}
291
292impl ChannelPool {
293    pub fn clear(&mut self) {
294        self.by_user.clear();
295        self.by_channel.clear();
296    }
297
298    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
299        self.by_user
300            .entry(user_id)
301            .or_default()
302            .insert(channel_id, role);
303        self.by_channel
304            .entry(channel_id)
305            .or_default()
306            .insert(user_id);
307    }
308
309    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
310        if let Some(channels) = self.by_user.get_mut(user_id) {
311            channels.remove(channel_id);
312            if channels.is_empty() {
313                self.by_user.remove(user_id);
314            }
315        }
316        if let Some(users) = self.by_channel.get_mut(channel_id) {
317            users.remove(user_id);
318            if users.is_empty() {
319                self.by_channel.remove(channel_id);
320            }
321        }
322    }
323
324    pub fn remove_user(&mut self, user_id: &UserId) {
325        if let Some(channels) = self.by_user.remove(&user_id) {
326            for channel_id in channels.keys() {
327                self.unsubscribe(user_id, &channel_id)
328            }
329        }
330    }
331
332    pub fn users_to_notify(
333        &self,
334        channel_id: ChannelId,
335    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
336        self.by_channel
337            .get(&channel_id)
338            .into_iter()
339            .flat_map(move |users| {
340                users.iter().flat_map(move |user_id| {
341                    Some((
342                        *user_id,
343                        self.by_user
344                            .get(user_id)
345                            .and_then(|channels| channels.get(&channel_id))
346                            .copied()?,
347                    ))
348                })
349            })
350    }
351}