connection_pool.rs

  1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use semantic_version::SemanticVersion;
  6use serde::Serialize;
  7use std::fmt;
  8use tracing::instrument;
  9
 10#[derive(Default, Serialize)]
 11pub struct ConnectionPool {
 12    connections: BTreeMap<ConnectionId, Connection>,
 13    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
 14    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
 15    channels: ChannelPool,
 16    offline_dev_servers: HashSet<DevServerId>,
 17}
 18
 19#[derive(Default, Serialize)]
 20struct ConnectedPrincipal {
 21    connection_ids: HashSet<ConnectionId>,
 22}
 23
 24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
 25pub struct ZedVersion(pub SemanticVersion);
 26
 27impl fmt::Display for ZedVersion {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33impl ZedVersion {
 34    pub fn can_collaborate(&self) -> bool {
 35        self.0 >= SemanticVersion::new(0, 129, 2)
 36    }
 37
 38    pub fn with_save_as() -> ZedVersion {
 39        ZedVersion(SemanticVersion::new(0, 134, 0))
 40    }
 41
 42    pub fn with_list_directory() -> ZedVersion {
 43        ZedVersion(SemanticVersion::new(0, 145, 0))
 44    }
 45
 46    pub fn with_search_candidates() -> ZedVersion {
 47        ZedVersion(SemanticVersion::new(0, 151, 0))
 48    }
 49}
 50
 51pub trait VersionedMessage {
 52    fn required_host_version(&self) -> Option<ZedVersion> {
 53        None
 54    }
 55}
 56
 57impl VersionedMessage for proto::SaveBuffer {
 58    fn required_host_version(&self) -> Option<ZedVersion> {
 59        if self.new_path.is_some() {
 60            Some(ZedVersion::with_save_as())
 61        } else {
 62            None
 63        }
 64    }
 65}
 66
 67impl VersionedMessage for proto::OpenNewBuffer {
 68    fn required_host_version(&self) -> Option<ZedVersion> {
 69        Some(ZedVersion::with_save_as())
 70    }
 71}
 72
 73#[derive(Serialize)]
 74pub struct Connection {
 75    pub principal_id: PrincipalId,
 76    pub admin: bool,
 77    pub zed_version: ZedVersion,
 78}
 79
 80impl ConnectionPool {
 81    pub fn reset(&mut self) {
 82        self.connections.clear();
 83        self.connected_users.clear();
 84        self.connected_dev_servers.clear();
 85        self.channels.clear();
 86    }
 87
 88    pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
 89        self.connections.get(&connection_id)
 90    }
 91
 92    #[instrument(skip(self))]
 93    pub fn add_connection(
 94        &mut self,
 95        connection_id: ConnectionId,
 96        user_id: UserId,
 97        admin: bool,
 98        zed_version: ZedVersion,
 99    ) {
100        self.connections.insert(
101            connection_id,
102            Connection {
103                principal_id: PrincipalId::UserId(user_id),
104                admin,
105                zed_version,
106            },
107        );
108        let connected_user = self.connected_users.entry(user_id).or_default();
109        connected_user.connection_ids.insert(connection_id);
110    }
111
112    pub fn add_dev_server(
113        &mut self,
114        connection_id: ConnectionId,
115        dev_server_id: DevServerId,
116        zed_version: ZedVersion,
117    ) {
118        self.connections.insert(
119            connection_id,
120            Connection {
121                principal_id: PrincipalId::DevServerId(dev_server_id),
122                admin: false,
123                zed_version,
124            },
125        );
126
127        self.connected_dev_servers
128            .insert(dev_server_id, connection_id);
129    }
130
131    #[instrument(skip(self))]
132    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
133        let connection = self
134            .connections
135            .get_mut(&connection_id)
136            .ok_or_else(|| anyhow!("no such connection"))?;
137
138        match connection.principal_id {
139            PrincipalId::UserId(user_id) => {
140                let connected_user = self.connected_users.get_mut(&user_id).unwrap();
141                connected_user.connection_ids.remove(&connection_id);
142                if connected_user.connection_ids.is_empty() {
143                    self.connected_users.remove(&user_id);
144                    self.channels.remove_user(&user_id);
145                }
146            }
147            PrincipalId::DevServerId(dev_server_id) => {
148                self.connected_dev_servers.remove(&dev_server_id);
149                self.offline_dev_servers.remove(&dev_server_id);
150            }
151        }
152        self.connections.remove(&connection_id).unwrap();
153        Ok(())
154    }
155
156    pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
157        self.offline_dev_servers.insert(dev_server_id);
158    }
159
160    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
161        self.connections.values()
162    }
163
164    pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
165        self.connected_users
166            .get(&user_id)
167            .into_iter()
168            .flat_map(|state| {
169                state
170                    .connection_ids
171                    .iter()
172                    .flat_map(|cid| self.connections.get(cid))
173            })
174    }
175
176    pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
177        self.connected_users
178            .get(&user_id)
179            .into_iter()
180            .flat_map(|state| &state.connection_ids)
181            .copied()
182    }
183
184    pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
185        if self.dev_server_connection_id(dev_server_id).is_some()
186            && !self.offline_dev_servers.contains(&dev_server_id)
187        {
188            proto::DevServerStatus::Online
189        } else {
190            proto::DevServerStatus::Offline
191        }
192    }
193
194    pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
195        self.connected_dev_servers.get(&dev_server_id).copied()
196    }
197
198    pub fn dev_server_connection_id_supporting(
199        &self,
200        dev_server_id: DevServerId,
201        required: ZedVersion,
202    ) -> Result<ConnectionId> {
203        match self.connected_dev_servers.get(&dev_server_id) {
204            Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
205            Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
206            None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
207        }
208    }
209
210    pub fn channel_user_ids(
211        &self,
212        channel_id: ChannelId,
213    ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
214        self.channels.users_to_notify(channel_id)
215    }
216
217    pub fn channel_connection_ids(
218        &self,
219        channel_id: ChannelId,
220    ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
221        self.channels
222            .users_to_notify(channel_id)
223            .flat_map(|(user_id, role)| {
224                self.user_connection_ids(user_id)
225                    .map(move |connection_id| (connection_id, role))
226            })
227    }
228
229    pub fn subscribe_to_channel(
230        &mut self,
231        user_id: UserId,
232        channel_id: ChannelId,
233        role: ChannelRole,
234    ) {
235        self.channels.subscribe(user_id, channel_id, role);
236    }
237
238    pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
239        self.channels.unsubscribe(user_id, channel_id);
240    }
241
242    pub fn is_user_online(&self, user_id: UserId) -> bool {
243        !self
244            .connected_users
245            .get(&user_id)
246            .unwrap_or(&Default::default())
247            .connection_ids
248            .is_empty()
249    }
250
251    #[cfg(test)]
252    pub fn check_invariants(&self) {
253        for (connection_id, connection) in &self.connections {
254            match &connection.principal_id {
255                PrincipalId::UserId(user_id) => {
256                    assert!(self
257                        .connected_users
258                        .get(user_id)
259                        .unwrap()
260                        .connection_ids
261                        .contains(connection_id));
262                }
263                PrincipalId::DevServerId(dev_server_id) => {
264                    assert_eq!(
265                        self.connected_dev_servers.get(&dev_server_id).unwrap(),
266                        connection_id
267                    );
268                }
269            }
270        }
271
272        for (user_id, state) in &self.connected_users {
273            for connection_id in &state.connection_ids {
274                assert_eq!(
275                    self.connections.get(connection_id).unwrap().principal_id,
276                    PrincipalId::UserId(*user_id)
277                );
278            }
279        }
280
281        for (dev_server_id, connection_id) in &self.connected_dev_servers {
282            assert_eq!(
283                self.connections.get(connection_id).unwrap().principal_id,
284                PrincipalId::DevServerId(*dev_server_id)
285            );
286        }
287    }
288}
289
290#[derive(Default, Serialize)]
291pub struct ChannelPool {
292    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
293    by_channel: HashMap<ChannelId, HashSet<UserId>>,
294}
295
296impl ChannelPool {
297    pub fn clear(&mut self) {
298        self.by_user.clear();
299        self.by_channel.clear();
300    }
301
302    pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
303        self.by_user
304            .entry(user_id)
305            .or_default()
306            .insert(channel_id, role);
307        self.by_channel
308            .entry(channel_id)
309            .or_default()
310            .insert(user_id);
311    }
312
313    pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
314        if let Some(channels) = self.by_user.get_mut(user_id) {
315            channels.remove(channel_id);
316            if channels.is_empty() {
317                self.by_user.remove(user_id);
318            }
319        }
320        if let Some(users) = self.by_channel.get_mut(channel_id) {
321            users.remove(user_id);
322            if users.is_empty() {
323                self.by_channel.remove(channel_id);
324            }
325        }
326    }
327
328    pub fn remove_user(&mut self, user_id: &UserId) {
329        if let Some(channels) = self.by_user.remove(&user_id) {
330            for channel_id in channels.keys() {
331                self.unsubscribe(user_id, &channel_id)
332            }
333        }
334    }
335
336    pub fn users_to_notify(
337        &self,
338        channel_id: ChannelId,
339    ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
340        self.by_channel
341            .get(&channel_id)
342            .into_iter()
343            .flat_map(move |users| {
344                users.iter().flat_map(move |user_id| {
345                    Some((
346                        *user_id,
347                        self.by_user
348                            .get(user_id)
349                            .and_then(|channels| channels.get(&channel_id))
350                            .copied()?,
351                    ))
352                })
353            })
354    }
355}