store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::collections::hash_map;
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40}
 41
 42pub struct WorktreeShare {
 43    pub entries: HashMap<u64, proto::Entry>,
 44}
 45
 46#[derive(Default)]
 47pub struct Channel {
 48    pub connection_ids: HashSet<ConnectionId>,
 49}
 50
 51pub type ReplicaId = u16;
 52
 53#[derive(Default)]
 54pub struct RemovedConnectionState {
 55    pub hosted_projects: HashMap<u64, Project>,
 56    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 57    pub contact_ids: HashSet<UserId>,
 58}
 59
 60pub struct JoinedProject<'a> {
 61    pub replica_id: ReplicaId,
 62    pub project: &'a Project,
 63}
 64
 65pub struct UnsharedWorktree {
 66    pub connection_ids: Vec<ConnectionId>,
 67    pub authorized_user_ids: Vec<UserId>,
 68}
 69
 70pub struct LeftProject {
 71    pub connection_ids: Vec<ConnectionId>,
 72    pub authorized_user_ids: Vec<UserId>,
 73}
 74
 75impl Store {
 76    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 77        self.connections.insert(
 78            connection_id,
 79            ConnectionState {
 80                user_id,
 81                projects: Default::default(),
 82                channels: Default::default(),
 83            },
 84        );
 85        self.connections_by_user_id
 86            .entry(user_id)
 87            .or_default()
 88            .insert(connection_id);
 89    }
 90
 91    pub fn remove_connection(
 92        &mut self,
 93        connection_id: ConnectionId,
 94    ) -> tide::Result<RemovedConnectionState> {
 95        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 96            connection
 97        } else {
 98            return Err(anyhow!("no such connection"))?;
 99        };
100
101        for channel_id in &connection.channels {
102            if let Some(channel) = self.channels.get_mut(&channel_id) {
103                channel.connection_ids.remove(&connection_id);
104            }
105        }
106
107        let user_connections = self
108            .connections_by_user_id
109            .get_mut(&connection.user_id)
110            .unwrap();
111        user_connections.remove(&connection_id);
112        if user_connections.is_empty() {
113            self.connections_by_user_id.remove(&connection.user_id);
114        }
115
116        let mut result = RemovedConnectionState::default();
117        for project_id in connection.projects.clone() {
118            if let Some((project, authorized_user_ids)) =
119                self.unregister_project(project_id, connection_id)
120            {
121                result.contact_ids.extend(authorized_user_ids);
122                result.hosted_projects.insert(project_id, project);
123            } else if let Some(project) = self.leave_project(connection_id, project_id) {
124                result
125                    .guest_project_ids
126                    .insert(project_id, project.connection_ids);
127                result.contact_ids.extend(project.authorized_user_ids);
128            }
129        }
130
131        #[cfg(test)]
132        self.check_invariants();
133
134        Ok(result)
135    }
136
137    #[cfg(test)]
138    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
139        self.channels.get(&id)
140    }
141
142    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143        if let Some(connection) = self.connections.get_mut(&connection_id) {
144            connection.channels.insert(channel_id);
145            self.channels
146                .entry(channel_id)
147                .or_default()
148                .connection_ids
149                .insert(connection_id);
150        }
151    }
152
153    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
154        if let Some(connection) = self.connections.get_mut(&connection_id) {
155            connection.channels.remove(&channel_id);
156            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
157                entry.get_mut().connection_ids.remove(&connection_id);
158                if entry.get_mut().connection_ids.is_empty() {
159                    entry.remove();
160                }
161            }
162        }
163    }
164
165    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
166        Ok(self
167            .connections
168            .get(&connection_id)
169            .ok_or_else(|| anyhow!("unknown connection"))?
170            .user_id)
171    }
172
173    pub fn connection_ids_for_user<'a>(
174        &'a self,
175        user_id: UserId,
176    ) -> impl 'a + Iterator<Item = ConnectionId> {
177        self.connections_by_user_id
178            .get(&user_id)
179            .into_iter()
180            .flatten()
181            .copied()
182    }
183
184    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
185        let mut contacts = HashMap::default();
186        for project_id in self
187            .visible_projects_by_user_id
188            .get(&user_id)
189            .unwrap_or(&HashSet::default())
190        {
191            let project = &self.projects[project_id];
192
193            let mut guests = HashSet::default();
194            if let Ok(share) = project.share() {
195                for guest_connection_id in share.guests.keys() {
196                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
197                        guests.insert(user_id.to_proto());
198                    }
199                }
200            }
201
202            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
203                let mut worktree_root_names = project
204                    .worktrees
205                    .values()
206                    .map(|worktree| worktree.root_name.clone())
207                    .collect::<Vec<_>>();
208                worktree_root_names.sort_unstable();
209                contacts
210                    .entry(host_user_id)
211                    .or_insert_with(|| proto::Contact {
212                        user_id: host_user_id.to_proto(),
213                        projects: Vec::new(),
214                    })
215                    .projects
216                    .push(proto::ProjectMetadata {
217                        id: *project_id,
218                        worktree_root_names,
219                        is_shared: project.share.is_some(),
220                        guests: guests.into_iter().collect(),
221                    });
222            }
223        }
224
225        contacts.into_values().collect()
226    }
227
228    pub fn register_project(
229        &mut self,
230        host_connection_id: ConnectionId,
231        host_user_id: UserId,
232    ) -> u64 {
233        let project_id = self.next_project_id;
234        self.projects.insert(
235            project_id,
236            Project {
237                host_connection_id,
238                host_user_id,
239                share: None,
240                worktrees: Default::default(),
241            },
242        );
243        self.next_project_id += 1;
244        project_id
245    }
246
247    pub fn register_worktree(
248        &mut self,
249        project_id: u64,
250        worktree_id: u64,
251        worktree: Worktree,
252    ) -> bool {
253        if let Some(project) = self.projects.get_mut(&project_id) {
254            for authorized_user_id in &worktree.authorized_user_ids {
255                self.visible_projects_by_user_id
256                    .entry(*authorized_user_id)
257                    .or_default()
258                    .insert(project_id);
259            }
260            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
261                connection.projects.insert(project_id);
262            }
263            project.worktrees.insert(worktree_id, worktree);
264
265            #[cfg(test)]
266            self.check_invariants();
267            true
268        } else {
269            false
270        }
271    }
272
273    pub fn unregister_project(
274        &mut self,
275        project_id: u64,
276        connection_id: ConnectionId,
277    ) -> Option<(Project, Vec<UserId>)> {
278        match self.projects.entry(project_id) {
279            hash_map::Entry::Occupied(e) => {
280                if e.get().host_connection_id != connection_id {
281                    return None;
282                }
283            }
284            hash_map::Entry::Vacant(_) => return None,
285        }
286
287        todo!()
288    }
289
290    pub fn unregister_worktree(
291        &mut self,
292        project_id: u64,
293        worktree_id: u64,
294        acting_connection_id: ConnectionId,
295    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
296        let project = self
297            .projects
298            .get_mut(&project_id)
299            .ok_or_else(|| anyhow!("no such project"))?;
300        if project.host_connection_id != acting_connection_id {
301            Err(anyhow!("not your worktree"))?;
302        }
303
304        let worktree = project
305            .worktrees
306            .remove(&worktree_id)
307            .ok_or_else(|| anyhow!("no such worktree"))?;
308
309        let mut guest_connection_ids = Vec::new();
310        if let Some(share) = &project.share {
311            guest_connection_ids.extend(share.guests.keys());
312        }
313
314        for authorized_user_id in &worktree.authorized_user_ids {
315            if let Some(visible_projects) =
316                self.visible_projects_by_user_id.get_mut(authorized_user_id)
317            {
318                if !project.has_authorized_user_id(*authorized_user_id) {
319                    visible_projects.remove(&project_id);
320                }
321            }
322        }
323
324        #[cfg(test)]
325        self.check_invariants();
326
327        Ok((worktree, guest_connection_ids))
328    }
329
330    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
331        if let Some(project) = self.projects.get_mut(&project_id) {
332            if project.host_connection_id == connection_id {
333                project.share = Some(ProjectShare::default());
334                return true;
335            }
336        }
337        false
338    }
339
340    pub fn unshare_project(
341        &mut self,
342        project_id: u64,
343        acting_connection_id: ConnectionId,
344    ) -> tide::Result<UnsharedWorktree> {
345        let project = if let Some(project) = self.projects.get_mut(&project_id) {
346            project
347        } else {
348            return Err(anyhow!("no such project"))?;
349        };
350
351        if project.host_connection_id != acting_connection_id {
352            return Err(anyhow!("not your project"))?;
353        }
354
355        let connection_ids = project.connection_ids();
356        let authorized_user_ids = project.authorized_user_ids();
357        if let Some(share) = project.share.take() {
358            for connection_id in share.guests.into_keys() {
359                if let Some(connection) = self.connections.get_mut(&connection_id) {
360                    connection.projects.remove(&project_id);
361                }
362            }
363
364            #[cfg(test)]
365            self.check_invariants();
366
367            Ok(UnsharedWorktree {
368                connection_ids,
369                authorized_user_ids,
370            })
371        } else {
372            Err(anyhow!("project is not shared"))?
373        }
374    }
375
376    pub fn share_worktree(
377        &mut self,
378        project_id: u64,
379        worktree_id: u64,
380        connection_id: ConnectionId,
381        entries: HashMap<u64, proto::Entry>,
382    ) -> Option<Vec<UserId>> {
383        let project = self.projects.get_mut(&project_id)?;
384        let worktree = project.worktrees.get_mut(&worktree_id)?;
385        if project.host_connection_id == connection_id && project.share.is_some() {
386            worktree.share = Some(WorktreeShare { entries });
387            Some(project.authorized_user_ids())
388        } else {
389            None
390        }
391    }
392
393    pub fn join_project(
394        &mut self,
395        connection_id: ConnectionId,
396        user_id: UserId,
397        project_id: u64,
398    ) -> tide::Result<JoinedProject> {
399        let connection = self
400            .connections
401            .get_mut(&connection_id)
402            .ok_or_else(|| anyhow!("no such connection"))?;
403        let project = self
404            .projects
405            .get_mut(&project_id)
406            .and_then(|project| {
407                if project.has_authorized_user_id(user_id) {
408                    Some(project)
409                } else {
410                    None
411                }
412            })
413            .ok_or_else(|| anyhow!("no such project"))?;
414
415        let share = project.share_mut()?;
416        connection.projects.insert(project_id);
417
418        let mut replica_id = 1;
419        while share.active_replica_ids.contains(&replica_id) {
420            replica_id += 1;
421        }
422        share.active_replica_ids.insert(replica_id);
423        share.guests.insert(connection_id, (replica_id, user_id));
424
425        #[cfg(test)]
426        self.check_invariants();
427
428        Ok(JoinedProject {
429            replica_id,
430            project: &self.projects[&project_id],
431        })
432    }
433
434    pub fn leave_project(
435        &mut self,
436        connection_id: ConnectionId,
437        project_id: u64,
438    ) -> Option<LeftProject> {
439        let project = self.projects.get_mut(&project_id)?;
440        let share = project.share.as_mut()?;
441        let (replica_id, _) = share.guests.remove(&connection_id)?;
442        share.active_replica_ids.remove(&replica_id);
443
444        if let Some(connection) = self.connections.get_mut(&connection_id) {
445            connection.projects.remove(&project_id);
446        }
447
448        let connection_ids = project.connection_ids();
449        let authorized_user_ids = project.authorized_user_ids();
450
451        #[cfg(test)]
452        self.check_invariants();
453
454        Some(LeftProject {
455            connection_ids,
456            authorized_user_ids,
457        })
458    }
459
460    pub fn update_worktree(
461        &mut self,
462        connection_id: ConnectionId,
463        project_id: u64,
464        worktree_id: u64,
465        removed_entries: &[u64],
466        updated_entries: &[proto::Entry],
467    ) -> Option<Vec<ConnectionId>> {
468        let project = self.write_project(project_id, connection_id)?;
469        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
470        for entry_id in removed_entries {
471            share.entries.remove(&entry_id);
472        }
473        for entry in updated_entries {
474            share.entries.insert(entry.id, entry.clone());
475        }
476        Some(project.connection_ids())
477    }
478
479    pub fn project_connection_ids(
480        &self,
481        project_id: u64,
482        acting_connection_id: ConnectionId,
483    ) -> Option<Vec<ConnectionId>> {
484        Some(
485            self.read_project(project_id, acting_connection_id)?
486                .connection_ids(),
487        )
488    }
489
490    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
491        Some(self.channels.get(&channel_id)?.connection_ids())
492    }
493
494    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
495        let project = self.projects.get(&project_id)?;
496        if project.host_connection_id == connection_id
497            || project.share.as_ref()?.guests.contains_key(&connection_id)
498        {
499            Some(project)
500        } else {
501            None
502        }
503    }
504
505    fn write_project(
506        &mut self,
507        project_id: u64,
508        connection_id: ConnectionId,
509    ) -> Option<&mut Project> {
510        let project = self.projects.get_mut(&project_id)?;
511        if project.host_connection_id == connection_id
512            || project.share.as_ref()?.guests.contains_key(&connection_id)
513        {
514            Some(project)
515        } else {
516            None
517        }
518    }
519
520    #[cfg(test)]
521    fn check_invariants(&self) {
522        for (connection_id, connection) in &self.connections {
523            for project_id in &connection.projects {
524                let project = &self.projects.get(&project_id).unwrap();
525                if project.host_connection_id != *connection_id {
526                    assert!(project
527                        .share
528                        .as_ref()
529                        .unwrap()
530                        .guests
531                        .contains_key(connection_id));
532                }
533            }
534            for channel_id in &connection.channels {
535                let channel = self.channels.get(channel_id).unwrap();
536                assert!(channel.connection_ids.contains(connection_id));
537            }
538            assert!(self
539                .connections_by_user_id
540                .get(&connection.user_id)
541                .unwrap()
542                .contains(connection_id));
543        }
544
545        for (user_id, connection_ids) in &self.connections_by_user_id {
546            for connection_id in connection_ids {
547                assert_eq!(
548                    self.connections.get(connection_id).unwrap().user_id,
549                    *user_id
550                );
551            }
552        }
553
554        for (project_id, project) in &self.projects {
555            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
556            assert!(host_connection.projects.contains(project_id));
557
558            for authorized_user_ids in project.authorized_user_ids() {
559                let visible_project_ids = self
560                    .visible_projects_by_user_id
561                    .get(&authorized_user_ids)
562                    .unwrap();
563                assert!(visible_project_ids.contains(project_id));
564            }
565
566            if let Some(share) = &project.share {
567                for guest_connection_id in share.guests.keys() {
568                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
569                    assert!(guest_connection.projects.contains(project_id));
570                }
571                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
572                assert_eq!(
573                    share.active_replica_ids,
574                    share
575                        .guests
576                        .values()
577                        .map(|(replica_id, _)| *replica_id)
578                        .collect::<HashSet<_>>(),
579                );
580            }
581        }
582
583        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
584            for project_id in visible_project_ids {
585                let project = self.projects.get(project_id).unwrap();
586                assert!(project.authorized_user_ids().contains(user_id));
587            }
588        }
589
590        for (channel_id, channel) in &self.channels {
591            for connection_id in &channel.connection_ids {
592                let connection = self.connections.get(connection_id).unwrap();
593                assert!(connection.channels.contains(channel_id));
594            }
595        }
596    }
597}
598
599impl Project {
600    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
601        self.worktrees
602            .values()
603            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
604    }
605
606    pub fn authorized_user_ids(&self) -> Vec<UserId> {
607        let mut ids = self
608            .worktrees
609            .values()
610            .flat_map(|worktree| worktree.authorized_user_ids.iter())
611            .copied()
612            .collect::<Vec<_>>();
613        ids.sort_unstable();
614        ids.dedup();
615        ids
616    }
617
618    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
619        if let Some(share) = &self.share {
620            share.guests.keys().copied().collect()
621        } else {
622            Vec::new()
623        }
624    }
625
626    pub fn connection_ids(&self) -> Vec<ConnectionId> {
627        if let Some(share) = &self.share {
628            share
629                .guests
630                .keys()
631                .copied()
632                .chain(Some(self.host_connection_id))
633                .collect()
634        } else {
635            vec![self.host_connection_id]
636        }
637    }
638
639    pub fn share(&self) -> tide::Result<&ProjectShare> {
640        Ok(self
641            .share
642            .as_ref()
643            .ok_or_else(|| anyhow!("worktree is not shared"))?)
644    }
645
646    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
647        Ok(self
648            .share
649            .as_mut()
650            .ok_or_else(|| anyhow!("worktree is not shared"))?)
651    }
652}
653
654impl Channel {
655    fn connection_ids(&self) -> Vec<ConnectionId> {
656        self.connection_ids.iter().copied().collect()
657    }
658}