connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
/// Registry of connections, indexed by connection id, by the user that opened
/// them, and by dev server, together with users' channel subscriptions.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every registered connection, keyed by its connection id.
    connections: BTreeMap<ConnectionId, Connection>,
    /// The set of connection ids belonging to each user with at least one connection.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    /// The single connection currently registered for each dev server.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    /// Channel subscriptions, indexed both by user and by channel.
    channels: ChannelPool,
    /// Dev servers marked offline via `set_dev_server_offline`; they report
    /// `Offline` from `dev_server_status` even while a connection is registered.
    offline_dev_servers: HashSet<DevServerId>,
}
 16    offline_dev_servers: HashSet<DevServerId>,
 17}
 18
/// Per-user bookkeeping: the ids of every connection the user currently has open.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    connection_ids: HashSet<ConnectionId>,
}
 23
/// The Zed version recorded for a connection. Ordering is derived, so it
/// compares the wrapped `SemanticVersion`.
#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 129, 2)
 36    }
 37
 38    pub fn with_save_as() -> ZedVersion {
 39        ZedVersion(SemanticVersion::new(0, 134, 0))
 40    }
 41}
 42
/// Implemented by messages that can declare a minimum host Zed version.
pub trait VersionedMessage {
    /// The minimum host version this message requires. The default of `None`
    /// means the message carries no version requirement.
    fn required_host_version(&self) -> Option<ZedVersion> {
        None
    }
}
 48
 49impl VersionedMessage for proto::SaveBuffer {
 50    fn required_host_version(&self) -> Option<ZedVersion> {
 51        if self.new_path.is_some() {
 52            Some(ZedVersion::with_save_as())
 53        } else {
 54            None
 55        }
 56    }
 57}
 58
 59impl VersionedMessage for proto::OpenNewBuffer {
 60    fn required_host_version(&self) -> Option<ZedVersion> {
 61        Some(ZedVersion::with_save_as())
 62    }
 63}
 64
/// Metadata stored for each connection registered in a `ConnectionPool`.
#[derive(Serialize)]
pub struct Connection {
    /// The user or dev server that owns this connection.
    pub principal_id: PrincipalId,
    /// Whether the connecting user is an admin; dev server connections are
    /// always registered with `false`.
    pub admin: bool,
    /// The Zed version supplied when the connection was registered.
    pub zed_version: ZedVersion,
}
 71
 72impl ConnectionPool {
 73    pub fn reset(&mut self) {
 74        self.connections.clear();
 75        self.connected_users.clear();
 76        self.connected_dev_servers.clear();
 77        self.channels.clear();
 78    }
 79
 80    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 81        self.connections.get(&connection_id)
 82    }
 83
 84    #[instrument(skip(self))]
 85    pub fn add_connection(
 86        &mut self,
 87        connection_id: ConnectionId,
 88        user_id: UserId,
 89        admin: bool,
 90        zed_version: ZedVersion,
 91    ) {
 92        self.connections.insert(
 93            connection_id,
 94            Connection {
 95                principal_id: PrincipalId::UserId(user_id),
 96                admin,
 97                zed_version,
 98            },
 99        );
100        let connected_user = self.connected_users.entry(user_id).or_default();
101        connected_user.connection_ids.insert(connection_id);
102    }
103
104    pub fn add_dev_server(
105        &mut self,
106        connection_id: ConnectionId,
107        dev_server_id: DevServerId,
108        zed_version: ZedVersion,
109    ) {
110        self.connections.insert(
111            connection_id,
112            Connection {
113                principal_id: PrincipalId::DevServerId(dev_server_id),
114                admin: false,
115                zed_version,
116            },
117        );
118
119        self.connected_dev_servers
120            .insert(dev_server_id, connection_id);
121    }
122
123    #[instrument(skip(self))]
124    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
125        let connection = self
126            .connections
127            .get_mut(&connection_id)
128            .ok_or_else(|| anyhow!("no such connection"))?;
129
130        match connection.principal_id {
131            PrincipalId::UserId(user_id) => {
132                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
133                connected_user.connection_ids.remove(&connection_id);
134                if connected_user.connection_ids.is_empty() {
135                    self.connected_users.remove(&user_id);
136                    self.channels.remove_user(&user_id);
137                }
138            }
139            PrincipalId::DevServerId(dev_server_id) => {
140                self.connected_dev_servers.remove(&dev_server_id);
141                self.offline_dev_servers.remove(&dev_server_id);
142            }
143        }
144        self.connections.remove(&connection_id).unwrap();
145        Ok(())
146    }
147
148    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
149        self.offline_dev_servers.insert(dev_server_id);
150    }
151
152    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
153        self.connections.values()
154    }
155
156    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
157        self.connected_users
158            .get(&user_id)
159            .into_iter()
160            .flat_map(|state| {
161                state
162                    .connection_ids
163                    .iter()
164                    .flat_map(|cid| self.connections.get(cid))
165            })
166    }
167
168    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
169        self.connected_users
170            .get(&user_id)
171            .into_iter()
172            .flat_map(|state| &state.connection_ids)
173            .copied()
174    }
175
176    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
177        if self.dev_server_connection_id(dev_server_id).is_some()
178            && !self.offline_dev_servers.contains(&dev_server_id)
179        {
180            proto::DevServerStatus::Online
181        } else {
182            proto::DevServerStatus::Offline
183        }
184    }
185
186    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
187        self.connected_dev_servers.get(&dev_server_id).copied()
188    }
189
190    pub fn channel_user_ids(
191        &self,
192        channel_id: ChannelId,
193    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
194        self.channels.users_to_notify(channel_id)
195    }
196
197    pub fn channel_connection_ids(
198        &self,
199        channel_id: ChannelId,
200    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
201        self.channels
202            .users_to_notify(channel_id)
203            .flat_map(|(user_id, role)| {
204                self.user_connection_ids(user_id)
205                    .map(move |connection_id| (connection_id, role))
206            })
207    }
208
209    pub fn subscribe_to_channel(
210        &mut self,
211        user_id: UserId,
212        channel_id: ChannelId,
213        role: ChannelRole,
214    ) {
215        self.channels.subscribe(user_id, channel_id, role);
216    }
217
218    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
219        self.channels.unsubscribe(user_id, channel_id);
220    }
221
222    pub fn is_user_online(&self, user_id: UserId) -> bool {
223        !self
224            .connected_users
225            .get(&user_id)
226            .unwrap_or(&Default::default())
227            .connection_ids
228            .is_empty()
229    }
230
231    #[cfg(test)]
232    pub fn check_invariants(&self) {
233        for (connection_id, connection) in &self.connections {
234            match &connection.principal_id {
235                PrincipalId::UserId(user_id) => {
236                    assert!(self
237                        .connected_users
238                        .get(user_id)
239                        .unwrap()
240                        .connection_ids
241                        .contains(connection_id));
242                }
243                PrincipalId::DevServerId(dev_server_id) => {
244                    assert_eq!(
245                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
246                        connection_id
247                    );
248                }
249            }
250        }
251
252        for (user_id, state) in &self.connected_users {
253            for connection_id in &state.connection_ids {
254                assert_eq!(
255                    self.connections.get(connection_id).unwrap().principal_id,
256                    PrincipalId::UserId(*user_id)
257                );
258            }
259        }
260
261        for (dev_server_id, connection_id) in &self.connected_dev_servers {
262            assert_eq!(
263                self.connections.get(connection_id).unwrap().principal_id,
264                PrincipalId::DevServerId(*dev_server_id)
265            );
266        }
267    }
268}
269
/// Channel subscriptions, kept in two mirrored indexes so they can be looked up
/// either by user or by channel.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// For each user, the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// For each channel, the users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
275
276impl ChannelPool {
277    pub fn clear(&mut self) {
278        self.by_user.clear();
279        self.by_channel.clear();
280    }
281
282    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
283        self.by_user
284            .entry(user_id)
285            .or_default()
286            .insert(channel_id, role);
287        self.by_channel
288            .entry(channel_id)
289            .or_default()
290            .insert(user_id);
291    }
292
293    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
294        if let Some(channels) = self.by_user.get_mut(user_id) {
295            channels.remove(channel_id);
296            if channels.is_empty() {
297                self.by_user.remove(user_id);
298            }
299        }
300        if let Some(users) = self.by_channel.get_mut(channel_id) {
301            users.remove(user_id);
302            if users.is_empty() {
303                self.by_channel.remove(channel_id);
304            }
305        }
306    }
307
308    pub fn remove_user(&mut self, user_id: &UserId) {
309        if let Some(channels) = self.by_user.remove(&user_id) {
310            for channel_id in channels.keys() {
311                self.unsubscribe(user_id, &channel_id)
312            }
313        }
314    }
315
316    pub fn users_to_notify(
317        &self,
318        channel_id: ChannelId,
319    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
320        self.by_channel
321            .get(&channel_id)
322            .into_iter()
323            .flat_map(move |users| {
324                users.iter().flat_map(move |user_id| {
325                    Some((
326                        *user_id,
327                        self.by_user
328                            .get(user_id)
329                            .and_then(|channels| channels.get(&channel_id))
330                            .copied()?,
331                    ))
332                })
333            })
334    }
335}