store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub weak: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47}
 48
 49#[derive(Default)]
 50pub struct Channel {
 51    pub connection_ids: HashSet<ConnectionId>,
 52}
 53
 54pub type ReplicaId = u16;
 55
 56#[derive(Default)]
 57pub struct RemovedConnectionState {
 58    pub hosted_projects: HashMap<u64, Project>,
 59    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 60    pub contact_ids: HashSet<UserId>,
 61}
 62
 63pub struct JoinedProject<'a> {
 64    pub replica_id: ReplicaId,
 65    pub project: &'a Project,
 66}
 67
 68pub struct UnsharedProject {
 69    pub connection_ids: Vec<ConnectionId>,
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct LeftProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78impl Store {
 79    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 80        self.connections.insert(
 81            connection_id,
 82            ConnectionState {
 83                user_id,
 84                projects: Default::default(),
 85                channels: Default::default(),
 86            },
 87        );
 88        self.connections_by_user_id
 89            .entry(user_id)
 90            .or_default()
 91            .insert(connection_id);
 92    }
 93
 94    pub fn remove_connection(
 95        &mut self,
 96        connection_id: ConnectionId,
 97    ) -> tide::Result<RemovedConnectionState> {
 98        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 99            connection
100        } else {
101            return Err(anyhow!("no such connection"))?;
102        };
103
104        for channel_id in &connection.channels {
105            if let Some(channel) = self.channels.get_mut(&channel_id) {
106                channel.connection_ids.remove(&connection_id);
107            }
108        }
109
110        let user_connections = self
111            .connections_by_user_id
112            .get_mut(&connection.user_id)
113            .unwrap();
114        user_connections.remove(&connection_id);
115        if user_connections.is_empty() {
116            self.connections_by_user_id.remove(&connection.user_id);
117        }
118
119        let mut result = RemovedConnectionState::default();
120        for project_id in connection.projects.clone() {
121            if let Ok(project) = self.unregister_project(project_id, connection_id) {
122                result.contact_ids.extend(project.authorized_user_ids());
123                result.hosted_projects.insert(project_id, project);
124            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125                result
126                    .guest_project_ids
127                    .insert(project_id, project.connection_ids);
128                result.contact_ids.extend(project.authorized_user_ids);
129            }
130        }
131
132        #[cfg(test)]
133        self.check_invariants();
134
135        Ok(result)
136    }
137
138    #[cfg(test)]
139    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140        self.channels.get(&id)
141    }
142
143    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.insert(channel_id);
146            self.channels
147                .entry(channel_id)
148                .or_default()
149                .connection_ids
150                .insert(connection_id);
151        }
152    }
153
154    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155        if let Some(connection) = self.connections.get_mut(&connection_id) {
156            connection.channels.remove(&channel_id);
157            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158                entry.get_mut().connection_ids.remove(&connection_id);
159                if entry.get_mut().connection_ids.is_empty() {
160                    entry.remove();
161                }
162            }
163        }
164    }
165
166    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167        Ok(self
168            .connections
169            .get(&connection_id)
170            .ok_or_else(|| anyhow!("unknown connection"))?
171            .user_id)
172    }
173
174    pub fn connection_ids_for_user<'a>(
175        &'a self,
176        user_id: UserId,
177    ) -> impl 'a + Iterator<Item = ConnectionId> {
178        self.connections_by_user_id
179            .get(&user_id)
180            .into_iter()
181            .flatten()
182            .copied()
183    }
184
185    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186        let mut contacts = HashMap::default();
187        for project_id in self
188            .visible_projects_by_user_id
189            .get(&user_id)
190            .unwrap_or(&HashSet::default())
191        {
192            let project = &self.projects[project_id];
193
194            let mut guests = HashSet::default();
195            if let Ok(share) = project.share() {
196                for guest_connection_id in share.guests.keys() {
197                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198                        guests.insert(user_id.to_proto());
199                    }
200                }
201            }
202
203            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204                let mut worktree_root_names = project
205                    .worktrees
206                    .values()
207                    .filter(|worktree| !worktree.weak)
208                    .map(|worktree| worktree.root_name.clone())
209                    .collect::<Vec<_>>();
210                worktree_root_names.sort_unstable();
211                contacts
212                    .entry(host_user_id)
213                    .or_insert_with(|| proto::Contact {
214                        user_id: host_user_id.to_proto(),
215                        projects: Vec::new(),
216                    })
217                    .projects
218                    .push(proto::ProjectMetadata {
219                        id: *project_id,
220                        worktree_root_names,
221                        is_shared: project.share.is_some(),
222                        guests: guests.into_iter().collect(),
223                    });
224            }
225        }
226
227        contacts.into_values().collect()
228    }
229
230    pub fn register_project(
231        &mut self,
232        host_connection_id: ConnectionId,
233        host_user_id: UserId,
234    ) -> u64 {
235        let project_id = self.next_project_id;
236        self.projects.insert(
237            project_id,
238            Project {
239                host_connection_id,
240                host_user_id,
241                share: None,
242                worktrees: Default::default(),
243            },
244        );
245        self.next_project_id += 1;
246        project_id
247    }
248
249    pub fn register_worktree(
250        &mut self,
251        project_id: u64,
252        worktree_id: u64,
253        connection_id: ConnectionId,
254        worktree: Worktree,
255    ) -> tide::Result<()> {
256        let project = self
257            .projects
258            .get_mut(&project_id)
259            .ok_or_else(|| anyhow!("no such project"))?;
260        if project.host_connection_id == connection_id {
261            for authorized_user_id in &worktree.authorized_user_ids {
262                self.visible_projects_by_user_id
263                    .entry(*authorized_user_id)
264                    .or_default()
265                    .insert(project_id);
266            }
267            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268                connection.projects.insert(project_id);
269            }
270            project.worktrees.insert(worktree_id, worktree);
271            if let Ok(share) = project.share_mut() {
272                share.worktrees.insert(worktree_id, Default::default());
273            }
274
275            #[cfg(test)]
276            self.check_invariants();
277            Ok(())
278        } else {
279            Err(anyhow!("no such project"))?
280        }
281    }
282
283    pub fn unregister_project(
284        &mut self,
285        project_id: u64,
286        connection_id: ConnectionId,
287    ) -> tide::Result<Project> {
288        match self.projects.entry(project_id) {
289            hash_map::Entry::Occupied(e) => {
290                if e.get().host_connection_id == connection_id {
291                    for user_id in e.get().authorized_user_ids() {
292                        if let hash_map::Entry::Occupied(mut projects) =
293                            self.visible_projects_by_user_id.entry(user_id)
294                        {
295                            projects.get_mut().remove(&project_id);
296                        }
297                    }
298
299                    let project = e.remove();
300
301                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
302                        host_connection.projects.remove(&project_id);
303                    }
304
305                    if let Some(share) = &project.share {
306                        for guest_connection in share.guests.keys() {
307                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
308                                connection.projects.remove(&project_id);
309                            }
310                        }
311                    }
312
313                    #[cfg(test)]
314                    self.check_invariants();
315                    Ok(project)
316                } else {
317                    Err(anyhow!("no such project"))?
318                }
319            }
320            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
321        }
322    }
323
324    pub fn unregister_worktree(
325        &mut self,
326        project_id: u64,
327        worktree_id: u64,
328        acting_connection_id: ConnectionId,
329    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
330        let project = self
331            .projects
332            .get_mut(&project_id)
333            .ok_or_else(|| anyhow!("no such project"))?;
334        if project.host_connection_id != acting_connection_id {
335            Err(anyhow!("not your worktree"))?;
336        }
337
338        let worktree = project
339            .worktrees
340            .remove(&worktree_id)
341            .ok_or_else(|| anyhow!("no such worktree"))?;
342
343        let mut guest_connection_ids = Vec::new();
344        if let Ok(share) = project.share_mut() {
345            guest_connection_ids.extend(share.guests.keys());
346            share.worktrees.remove(&worktree_id);
347        }
348
349        for authorized_user_id in &worktree.authorized_user_ids {
350            if let Some(visible_projects) =
351                self.visible_projects_by_user_id.get_mut(authorized_user_id)
352            {
353                if !project.has_authorized_user_id(*authorized_user_id) {
354                    visible_projects.remove(&project_id);
355                }
356            }
357        }
358
359        #[cfg(test)]
360        self.check_invariants();
361
362        Ok((worktree, guest_connection_ids))
363    }
364
365    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
366        if let Some(project) = self.projects.get_mut(&project_id) {
367            if project.host_connection_id == connection_id {
368                let mut share = ProjectShare::default();
369                for worktree_id in project.worktrees.keys() {
370                    share.worktrees.insert(*worktree_id, Default::default());
371                }
372                project.share = Some(share);
373                return true;
374            }
375        }
376        false
377    }
378
379    pub fn unshare_project(
380        &mut self,
381        project_id: u64,
382        acting_connection_id: ConnectionId,
383    ) -> tide::Result<UnsharedProject> {
384        let project = if let Some(project) = self.projects.get_mut(&project_id) {
385            project
386        } else {
387            return Err(anyhow!("no such project"))?;
388        };
389
390        if project.host_connection_id != acting_connection_id {
391            return Err(anyhow!("not your project"))?;
392        }
393
394        let connection_ids = project.connection_ids();
395        let authorized_user_ids = project.authorized_user_ids();
396        if let Some(share) = project.share.take() {
397            for connection_id in share.guests.into_keys() {
398                if let Some(connection) = self.connections.get_mut(&connection_id) {
399                    connection.projects.remove(&project_id);
400                }
401            }
402
403            #[cfg(test)]
404            self.check_invariants();
405
406            Ok(UnsharedProject {
407                connection_ids,
408                authorized_user_ids,
409            })
410        } else {
411            Err(anyhow!("project is not shared"))?
412        }
413    }
414
415    pub fn update_diagnostic_summary(
416        &mut self,
417        project_id: u64,
418        worktree_id: u64,
419        connection_id: ConnectionId,
420        summary: proto::DiagnosticSummary,
421    ) -> tide::Result<Vec<ConnectionId>> {
422        let project = self
423            .projects
424            .get_mut(&project_id)
425            .ok_or_else(|| anyhow!("no such project"))?;
426        if project.host_connection_id == connection_id {
427            let worktree = project
428                .share_mut()?
429                .worktrees
430                .get_mut(&worktree_id)
431                .ok_or_else(|| anyhow!("no such worktree"))?;
432            worktree
433                .diagnostic_summaries
434                .insert(summary.path.clone().into(), summary);
435            return Ok(project.connection_ids());
436        }
437
438        Err(anyhow!("no such worktree"))?
439    }
440
441    pub fn join_project(
442        &mut self,
443        connection_id: ConnectionId,
444        user_id: UserId,
445        project_id: u64,
446    ) -> tide::Result<JoinedProject> {
447        let connection = self
448            .connections
449            .get_mut(&connection_id)
450            .ok_or_else(|| anyhow!("no such connection"))?;
451        let project = self
452            .projects
453            .get_mut(&project_id)
454            .and_then(|project| {
455                if project.has_authorized_user_id(user_id) {
456                    Some(project)
457                } else {
458                    None
459                }
460            })
461            .ok_or_else(|| anyhow!("no such project"))?;
462
463        let share = project.share_mut()?;
464        connection.projects.insert(project_id);
465
466        let mut replica_id = 1;
467        while share.active_replica_ids.contains(&replica_id) {
468            replica_id += 1;
469        }
470        share.active_replica_ids.insert(replica_id);
471        share.guests.insert(connection_id, (replica_id, user_id));
472
473        #[cfg(test)]
474        self.check_invariants();
475
476        Ok(JoinedProject {
477            replica_id,
478            project: &self.projects[&project_id],
479        })
480    }
481
482    pub fn leave_project(
483        &mut self,
484        connection_id: ConnectionId,
485        project_id: u64,
486    ) -> tide::Result<LeftProject> {
487        let project = self
488            .projects
489            .get_mut(&project_id)
490            .ok_or_else(|| anyhow!("no such project"))?;
491        let share = project
492            .share
493            .as_mut()
494            .ok_or_else(|| anyhow!("project is not shared"))?;
495        let (replica_id, _) = share
496            .guests
497            .remove(&connection_id)
498            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
499        share.active_replica_ids.remove(&replica_id);
500
501        if let Some(connection) = self.connections.get_mut(&connection_id) {
502            connection.projects.remove(&project_id);
503        }
504
505        let connection_ids = project.connection_ids();
506        let authorized_user_ids = project.authorized_user_ids();
507
508        #[cfg(test)]
509        self.check_invariants();
510
511        Ok(LeftProject {
512            connection_ids,
513            authorized_user_ids,
514        })
515    }
516
517    pub fn update_worktree(
518        &mut self,
519        connection_id: ConnectionId,
520        project_id: u64,
521        worktree_id: u64,
522        removed_entries: &[u64],
523        updated_entries: &[proto::Entry],
524    ) -> tide::Result<Vec<ConnectionId>> {
525        let project = self.write_project(project_id, connection_id)?;
526        let worktree = project
527            .share_mut()?
528            .worktrees
529            .get_mut(&worktree_id)
530            .ok_or_else(|| anyhow!("no such worktree"))?;
531        for entry_id in removed_entries {
532            worktree.entries.remove(&entry_id);
533        }
534        for entry in updated_entries {
535            worktree.entries.insert(entry.id, entry.clone());
536        }
537        Ok(project.connection_ids())
538    }
539
540    pub fn project_connection_ids(
541        &self,
542        project_id: u64,
543        acting_connection_id: ConnectionId,
544    ) -> tide::Result<Vec<ConnectionId>> {
545        Ok(self
546            .read_project(project_id, acting_connection_id)?
547            .connection_ids())
548    }
549
550    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
551        Ok(self
552            .channels
553            .get(&channel_id)
554            .ok_or_else(|| anyhow!("no such channel"))?
555            .connection_ids())
556    }
557
558    #[cfg(test)]
559    pub fn project(&self, project_id: u64) -> Option<&Project> {
560        self.projects.get(&project_id)
561    }
562
563    pub fn read_project(
564        &self,
565        project_id: u64,
566        connection_id: ConnectionId,
567    ) -> tide::Result<&Project> {
568        let project = self
569            .projects
570            .get(&project_id)
571            .ok_or_else(|| anyhow!("no such project"))?;
572        if project.host_connection_id == connection_id
573            || project
574                .share
575                .as_ref()
576                .ok_or_else(|| anyhow!("project is not shared"))?
577                .guests
578                .contains_key(&connection_id)
579        {
580            Ok(project)
581        } else {
582            Err(anyhow!("no such project"))?
583        }
584    }
585
586    fn write_project(
587        &mut self,
588        project_id: u64,
589        connection_id: ConnectionId,
590    ) -> tide::Result<&mut Project> {
591        let project = self
592            .projects
593            .get_mut(&project_id)
594            .ok_or_else(|| anyhow!("no such project"))?;
595        if project.host_connection_id == connection_id
596            || project
597                .share
598                .as_ref()
599                .ok_or_else(|| anyhow!("project is not shared"))?
600                .guests
601                .contains_key(&connection_id)
602        {
603            Ok(project)
604        } else {
605            Err(anyhow!("no such project"))?
606        }
607    }
608
609    #[cfg(test)]
610    fn check_invariants(&self) {
611        for (connection_id, connection) in &self.connections {
612            for project_id in &connection.projects {
613                let project = &self.projects.get(&project_id).unwrap();
614                if project.host_connection_id != *connection_id {
615                    assert!(project
616                        .share
617                        .as_ref()
618                        .unwrap()
619                        .guests
620                        .contains_key(connection_id));
621                }
622            }
623            for channel_id in &connection.channels {
624                let channel = self.channels.get(channel_id).unwrap();
625                assert!(channel.connection_ids.contains(connection_id));
626            }
627            assert!(self
628                .connections_by_user_id
629                .get(&connection.user_id)
630                .unwrap()
631                .contains(connection_id));
632        }
633
634        for (user_id, connection_ids) in &self.connections_by_user_id {
635            for connection_id in connection_ids {
636                assert_eq!(
637                    self.connections.get(connection_id).unwrap().user_id,
638                    *user_id
639                );
640            }
641        }
642
643        for (project_id, project) in &self.projects {
644            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
645            assert!(host_connection.projects.contains(project_id));
646
647            for authorized_user_ids in project.authorized_user_ids() {
648                let visible_project_ids = self
649                    .visible_projects_by_user_id
650                    .get(&authorized_user_ids)
651                    .unwrap();
652                assert!(visible_project_ids.contains(project_id));
653            }
654
655            if let Some(share) = &project.share {
656                for guest_connection_id in share.guests.keys() {
657                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
658                    assert!(guest_connection.projects.contains(project_id));
659                }
660                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
661                assert_eq!(
662                    share.active_replica_ids,
663                    share
664                        .guests
665                        .values()
666                        .map(|(replica_id, _)| *replica_id)
667                        .collect::<HashSet<_>>(),
668                );
669            }
670        }
671
672        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
673            for project_id in visible_project_ids {
674                let project = self.projects.get(project_id).unwrap();
675                assert!(project.authorized_user_ids().contains(user_id));
676            }
677        }
678
679        for (channel_id, channel) in &self.channels {
680            for connection_id in &channel.connection_ids {
681                let connection = self.connections.get(connection_id).unwrap();
682                assert!(connection.channels.contains(channel_id));
683            }
684        }
685    }
686}
687
688impl Project {
689    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
690        self.worktrees
691            .values()
692            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
693    }
694
695    pub fn authorized_user_ids(&self) -> Vec<UserId> {
696        let mut ids = self
697            .worktrees
698            .values()
699            .flat_map(|worktree| worktree.authorized_user_ids.iter())
700            .copied()
701            .collect::<Vec<_>>();
702        ids.sort_unstable();
703        ids.dedup();
704        ids
705    }
706
707    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
708        if let Some(share) = &self.share {
709            share.guests.keys().copied().collect()
710        } else {
711            Vec::new()
712        }
713    }
714
715    pub fn connection_ids(&self) -> Vec<ConnectionId> {
716        if let Some(share) = &self.share {
717            share
718                .guests
719                .keys()
720                .copied()
721                .chain(Some(self.host_connection_id))
722                .collect()
723        } else {
724            vec![self.host_connection_id]
725        }
726    }
727
728    pub fn share(&self) -> tide::Result<&ProjectShare> {
729        Ok(self
730            .share
731            .as_ref()
732            .ok_or_else(|| anyhow!("worktree is not shared"))?)
733    }
734
735    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
736        Ok(self
737            .share
738            .as_mut()
739            .ok_or_else(|| anyhow!("worktree is not shared"))?)
740    }
741}
742
743impl Channel {
744    fn connection_ids(&self) -> Vec<ConnectionId> {
745        self.connection_ids.iter().copied().collect()
746    }
747}