store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34    pub weak: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41}
 42
 43pub struct WorktreeShare {
 44    pub entries: HashMap<u64, proto::Entry>,
 45    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 46    pub next_update_id: u64,
 47}
 48
 49#[derive(Default)]
 50pub struct Channel {
 51    pub connection_ids: HashSet<ConnectionId>,
 52}
 53
 54pub type ReplicaId = u16;
 55
 56#[derive(Default)]
 57pub struct RemovedConnectionState {
 58    pub hosted_projects: HashMap<u64, Project>,
 59    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 60    pub contact_ids: HashSet<UserId>,
 61}
 62
 63pub struct JoinedProject<'a> {
 64    pub replica_id: ReplicaId,
 65    pub project: &'a Project,
 66}
 67
 68pub struct UnsharedProject {
 69    pub connection_ids: Vec<ConnectionId>,
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct LeftProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78pub struct SharedWorktree {
 79    pub authorized_user_ids: Vec<UserId>,
 80    pub connection_ids: Vec<ConnectionId>,
 81}
 82
 83impl Store {
 84    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 85        self.connections.insert(
 86            connection_id,
 87            ConnectionState {
 88                user_id,
 89                projects: Default::default(),
 90                channels: Default::default(),
 91            },
 92        );
 93        self.connections_by_user_id
 94            .entry(user_id)
 95            .or_default()
 96            .insert(connection_id);
 97    }
 98
 99    pub fn remove_connection(
100        &mut self,
101        connection_id: ConnectionId,
102    ) -> tide::Result<RemovedConnectionState> {
103        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
104            connection
105        } else {
106            return Err(anyhow!("no such connection"))?;
107        };
108
109        for channel_id in &connection.channels {
110            if let Some(channel) = self.channels.get_mut(&channel_id) {
111                channel.connection_ids.remove(&connection_id);
112            }
113        }
114
115        let user_connections = self
116            .connections_by_user_id
117            .get_mut(&connection.user_id)
118            .unwrap();
119        user_connections.remove(&connection_id);
120        if user_connections.is_empty() {
121            self.connections_by_user_id.remove(&connection.user_id);
122        }
123
124        let mut result = RemovedConnectionState::default();
125        for project_id in connection.projects.clone() {
126            if let Ok(project) = self.unregister_project(project_id, connection_id) {
127                result.contact_ids.extend(project.authorized_user_ids());
128                result.hosted_projects.insert(project_id, project);
129            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
130                result
131                    .guest_project_ids
132                    .insert(project_id, project.connection_ids);
133                result.contact_ids.extend(project.authorized_user_ids);
134            }
135        }
136
137        #[cfg(test)]
138        self.check_invariants();
139
140        Ok(result)
141    }
142
143    #[cfg(test)]
144    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
145        self.channels.get(&id)
146    }
147
148    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
149        if let Some(connection) = self.connections.get_mut(&connection_id) {
150            connection.channels.insert(channel_id);
151            self.channels
152                .entry(channel_id)
153                .or_default()
154                .connection_ids
155                .insert(connection_id);
156        }
157    }
158
159    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
160        if let Some(connection) = self.connections.get_mut(&connection_id) {
161            connection.channels.remove(&channel_id);
162            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
163                entry.get_mut().connection_ids.remove(&connection_id);
164                if entry.get_mut().connection_ids.is_empty() {
165                    entry.remove();
166                }
167            }
168        }
169    }
170
171    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
172        Ok(self
173            .connections
174            .get(&connection_id)
175            .ok_or_else(|| anyhow!("unknown connection"))?
176            .user_id)
177    }
178
179    pub fn connection_ids_for_user<'a>(
180        &'a self,
181        user_id: UserId,
182    ) -> impl 'a + Iterator<Item = ConnectionId> {
183        self.connections_by_user_id
184            .get(&user_id)
185            .into_iter()
186            .flatten()
187            .copied()
188    }
189
190    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
191        let mut contacts = HashMap::default();
192        for project_id in self
193            .visible_projects_by_user_id
194            .get(&user_id)
195            .unwrap_or(&HashSet::default())
196        {
197            let project = &self.projects[project_id];
198
199            let mut guests = HashSet::default();
200            if let Ok(share) = project.share() {
201                for guest_connection_id in share.guests.keys() {
202                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
203                        guests.insert(user_id.to_proto());
204                    }
205                }
206            }
207
208            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
209                let mut worktree_root_names = project
210                    .worktrees
211                    .values()
212                    .filter(|worktree| !worktree.weak)
213                    .map(|worktree| worktree.root_name.clone())
214                    .collect::<Vec<_>>();
215                worktree_root_names.sort_unstable();
216                contacts
217                    .entry(host_user_id)
218                    .or_insert_with(|| proto::Contact {
219                        user_id: host_user_id.to_proto(),
220                        projects: Vec::new(),
221                    })
222                    .projects
223                    .push(proto::ProjectMetadata {
224                        id: *project_id,
225                        worktree_root_names,
226                        is_shared: project.share.is_some(),
227                        guests: guests.into_iter().collect(),
228                    });
229            }
230        }
231
232        contacts.into_values().collect()
233    }
234
235    pub fn register_project(
236        &mut self,
237        host_connection_id: ConnectionId,
238        host_user_id: UserId,
239    ) -> u64 {
240        let project_id = self.next_project_id;
241        self.projects.insert(
242            project_id,
243            Project {
244                host_connection_id,
245                host_user_id,
246                share: None,
247                worktrees: Default::default(),
248            },
249        );
250        self.next_project_id += 1;
251        project_id
252    }
253
254    pub fn register_worktree(
255        &mut self,
256        project_id: u64,
257        worktree_id: u64,
258        connection_id: ConnectionId,
259        worktree: Worktree,
260    ) -> tide::Result<()> {
261        let project = self
262            .projects
263            .get_mut(&project_id)
264            .ok_or_else(|| anyhow!("no such project"))?;
265        if project.host_connection_id == connection_id {
266            for authorized_user_id in &worktree.authorized_user_ids {
267                self.visible_projects_by_user_id
268                    .entry(*authorized_user_id)
269                    .or_default()
270                    .insert(project_id);
271            }
272            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
273                connection.projects.insert(project_id);
274            }
275            project.worktrees.insert(worktree_id, worktree);
276
277            #[cfg(test)]
278            self.check_invariants();
279            Ok(())
280        } else {
281            Err(anyhow!("no such project"))?
282        }
283    }
284
285    pub fn unregister_project(
286        &mut self,
287        project_id: u64,
288        connection_id: ConnectionId,
289    ) -> tide::Result<Project> {
290        match self.projects.entry(project_id) {
291            hash_map::Entry::Occupied(e) => {
292                if e.get().host_connection_id == connection_id {
293                    for user_id in e.get().authorized_user_ids() {
294                        if let hash_map::Entry::Occupied(mut projects) =
295                            self.visible_projects_by_user_id.entry(user_id)
296                        {
297                            projects.get_mut().remove(&project_id);
298                        }
299                    }
300
301                    Ok(e.remove())
302                } else {
303                    Err(anyhow!("no such project"))?
304                }
305            }
306            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
307        }
308    }
309
310    pub fn unregister_worktree(
311        &mut self,
312        project_id: u64,
313        worktree_id: u64,
314        acting_connection_id: ConnectionId,
315    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
316        let project = self
317            .projects
318            .get_mut(&project_id)
319            .ok_or_else(|| anyhow!("no such project"))?;
320        if project.host_connection_id != acting_connection_id {
321            Err(anyhow!("not your worktree"))?;
322        }
323
324        let worktree = project
325            .worktrees
326            .remove(&worktree_id)
327            .ok_or_else(|| anyhow!("no such worktree"))?;
328
329        let mut guest_connection_ids = Vec::new();
330        if let Some(share) = &project.share {
331            guest_connection_ids.extend(share.guests.keys());
332        }
333
334        for authorized_user_id in &worktree.authorized_user_ids {
335            if let Some(visible_projects) =
336                self.visible_projects_by_user_id.get_mut(authorized_user_id)
337            {
338                if !project.has_authorized_user_id(*authorized_user_id) {
339                    visible_projects.remove(&project_id);
340                }
341            }
342        }
343
344        #[cfg(test)]
345        self.check_invariants();
346
347        Ok((worktree, guest_connection_ids))
348    }
349
350    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
351        if let Some(project) = self.projects.get_mut(&project_id) {
352            if project.host_connection_id == connection_id {
353                project.share = Some(ProjectShare::default());
354                return true;
355            }
356        }
357        false
358    }
359
360    pub fn unshare_project(
361        &mut self,
362        project_id: u64,
363        acting_connection_id: ConnectionId,
364    ) -> tide::Result<UnsharedProject> {
365        let project = if let Some(project) = self.projects.get_mut(&project_id) {
366            project
367        } else {
368            return Err(anyhow!("no such project"))?;
369        };
370
371        if project.host_connection_id != acting_connection_id {
372            return Err(anyhow!("not your project"))?;
373        }
374
375        let connection_ids = project.connection_ids();
376        let authorized_user_ids = project.authorized_user_ids();
377        if let Some(share) = project.share.take() {
378            for connection_id in share.guests.into_keys() {
379                if let Some(connection) = self.connections.get_mut(&connection_id) {
380                    connection.projects.remove(&project_id);
381                }
382            }
383
384            for worktree in project.worktrees.values_mut() {
385                worktree.share.take();
386            }
387
388            #[cfg(test)]
389            self.check_invariants();
390
391            Ok(UnsharedProject {
392                connection_ids,
393                authorized_user_ids,
394            })
395        } else {
396            Err(anyhow!("project is not shared"))?
397        }
398    }
399
400    pub fn share_worktree(
401        &mut self,
402        project_id: u64,
403        worktree_id: u64,
404        connection_id: ConnectionId,
405        entries: HashMap<u64, proto::Entry>,
406        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
407        next_update_id: u64,
408    ) -> tide::Result<SharedWorktree> {
409        let project = self
410            .projects
411            .get_mut(&project_id)
412            .ok_or_else(|| anyhow!("no such project"))?;
413        let worktree = project
414            .worktrees
415            .get_mut(&worktree_id)
416            .ok_or_else(|| anyhow!("no such worktree"))?;
417        if project.host_connection_id == connection_id && project.share.is_some() {
418            worktree.share = Some(WorktreeShare {
419                entries,
420                diagnostic_summaries,
421                next_update_id,
422            });
423            Ok(SharedWorktree {
424                authorized_user_ids: project.authorized_user_ids(),
425                connection_ids: project.guest_connection_ids(),
426            })
427        } else {
428            Err(anyhow!("no such worktree"))?
429        }
430    }
431
432    pub fn update_diagnostic_summary(
433        &mut self,
434        project_id: u64,
435        worktree_id: u64,
436        connection_id: ConnectionId,
437        summary: proto::DiagnosticSummary,
438    ) -> tide::Result<Vec<ConnectionId>> {
439        let project = self
440            .projects
441            .get_mut(&project_id)
442            .ok_or_else(|| anyhow!("no such project"))?;
443        let worktree = project
444            .worktrees
445            .get_mut(&worktree_id)
446            .ok_or_else(|| anyhow!("no such worktree"))?;
447        if project.host_connection_id == connection_id {
448            if let Some(share) = worktree.share.as_mut() {
449                share
450                    .diagnostic_summaries
451                    .insert(summary.path.clone().into(), summary);
452                return Ok(project.connection_ids());
453            }
454        }
455
456        Err(anyhow!("no such worktree"))?
457    }
458
459    pub fn join_project(
460        &mut self,
461        connection_id: ConnectionId,
462        user_id: UserId,
463        project_id: u64,
464    ) -> tide::Result<JoinedProject> {
465        let connection = self
466            .connections
467            .get_mut(&connection_id)
468            .ok_or_else(|| anyhow!("no such connection"))?;
469        let project = self
470            .projects
471            .get_mut(&project_id)
472            .and_then(|project| {
473                if project.has_authorized_user_id(user_id) {
474                    Some(project)
475                } else {
476                    None
477                }
478            })
479            .ok_or_else(|| anyhow!("no such project"))?;
480
481        let share = project.share_mut()?;
482        connection.projects.insert(project_id);
483
484        let mut replica_id = 1;
485        while share.active_replica_ids.contains(&replica_id) {
486            replica_id += 1;
487        }
488        share.active_replica_ids.insert(replica_id);
489        share.guests.insert(connection_id, (replica_id, user_id));
490
491        #[cfg(test)]
492        self.check_invariants();
493
494        Ok(JoinedProject {
495            replica_id,
496            project: &self.projects[&project_id],
497        })
498    }
499
500    pub fn leave_project(
501        &mut self,
502        connection_id: ConnectionId,
503        project_id: u64,
504    ) -> tide::Result<LeftProject> {
505        let project = self
506            .projects
507            .get_mut(&project_id)
508            .ok_or_else(|| anyhow!("no such project"))?;
509        let share = project
510            .share
511            .as_mut()
512            .ok_or_else(|| anyhow!("project is not shared"))?;
513        let (replica_id, _) = share
514            .guests
515            .remove(&connection_id)
516            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
517        share.active_replica_ids.remove(&replica_id);
518
519        if let Some(connection) = self.connections.get_mut(&connection_id) {
520            connection.projects.remove(&project_id);
521        }
522
523        let connection_ids = project.connection_ids();
524        let authorized_user_ids = project.authorized_user_ids();
525
526        #[cfg(test)]
527        self.check_invariants();
528
529        Ok(LeftProject {
530            connection_ids,
531            authorized_user_ids,
532        })
533    }
534
535    pub fn update_worktree(
536        &mut self,
537        connection_id: ConnectionId,
538        project_id: u64,
539        worktree_id: u64,
540        update_id: u64,
541        removed_entries: &[u64],
542        updated_entries: &[proto::Entry],
543    ) -> tide::Result<Vec<ConnectionId>> {
544        let project = self.write_project(project_id, connection_id)?;
545        let share = project
546            .worktrees
547            .get_mut(&worktree_id)
548            .ok_or_else(|| anyhow!("no such worktree"))?
549            .share
550            .as_mut()
551            .ok_or_else(|| anyhow!("worktree is not shared"))?;
552        if share.next_update_id != update_id {
553            return Err(anyhow!("received worktree updates out-of-order"))?;
554        }
555
556        share.next_update_id = update_id + 1;
557        for entry_id in removed_entries {
558            share.entries.remove(&entry_id);
559        }
560        for entry in updated_entries {
561            share.entries.insert(entry.id, entry.clone());
562        }
563        Ok(project.connection_ids())
564    }
565
566    pub fn project_connection_ids(
567        &self,
568        project_id: u64,
569        acting_connection_id: ConnectionId,
570    ) -> tide::Result<Vec<ConnectionId>> {
571        Ok(self
572            .read_project(project_id, acting_connection_id)?
573            .connection_ids())
574    }
575
576    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
577        Ok(self
578            .channels
579            .get(&channel_id)
580            .ok_or_else(|| anyhow!("no such channel"))?
581            .connection_ids())
582    }
583
584    #[cfg(test)]
585    pub fn project(&self, project_id: u64) -> Option<&Project> {
586        self.projects.get(&project_id)
587    }
588
589    pub fn read_project(
590        &self,
591        project_id: u64,
592        connection_id: ConnectionId,
593    ) -> tide::Result<&Project> {
594        let project = self
595            .projects
596            .get(&project_id)
597            .ok_or_else(|| anyhow!("no such project"))?;
598        if project.host_connection_id == connection_id
599            || project
600                .share
601                .as_ref()
602                .ok_or_else(|| anyhow!("project is not shared"))?
603                .guests
604                .contains_key(&connection_id)
605        {
606            Ok(project)
607        } else {
608            Err(anyhow!("no such project"))?
609        }
610    }
611
612    fn write_project(
613        &mut self,
614        project_id: u64,
615        connection_id: ConnectionId,
616    ) -> tide::Result<&mut Project> {
617        let project = self
618            .projects
619            .get_mut(&project_id)
620            .ok_or_else(|| anyhow!("no such project"))?;
621        if project.host_connection_id == connection_id
622            || project
623                .share
624                .as_ref()
625                .ok_or_else(|| anyhow!("project is not shared"))?
626                .guests
627                .contains_key(&connection_id)
628        {
629            Ok(project)
630        } else {
631            Err(anyhow!("no such project"))?
632        }
633    }
634
635    #[cfg(test)]
636    fn check_invariants(&self) {
637        for (connection_id, connection) in &self.connections {
638            for project_id in &connection.projects {
639                let project = &self.projects.get(&project_id).unwrap();
640                if project.host_connection_id != *connection_id {
641                    assert!(project
642                        .share
643                        .as_ref()
644                        .unwrap()
645                        .guests
646                        .contains_key(connection_id));
647                }
648            }
649            for channel_id in &connection.channels {
650                let channel = self.channels.get(channel_id).unwrap();
651                assert!(channel.connection_ids.contains(connection_id));
652            }
653            assert!(self
654                .connections_by_user_id
655                .get(&connection.user_id)
656                .unwrap()
657                .contains(connection_id));
658        }
659
660        for (user_id, connection_ids) in &self.connections_by_user_id {
661            for connection_id in connection_ids {
662                assert_eq!(
663                    self.connections.get(connection_id).unwrap().user_id,
664                    *user_id
665                );
666            }
667        }
668
669        for (project_id, project) in &self.projects {
670            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
671            assert!(host_connection.projects.contains(project_id));
672
673            for authorized_user_ids in project.authorized_user_ids() {
674                let visible_project_ids = self
675                    .visible_projects_by_user_id
676                    .get(&authorized_user_ids)
677                    .unwrap();
678                assert!(visible_project_ids.contains(project_id));
679            }
680
681            if let Some(share) = &project.share {
682                for guest_connection_id in share.guests.keys() {
683                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
684                    assert!(guest_connection.projects.contains(project_id));
685                }
686                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
687                assert_eq!(
688                    share.active_replica_ids,
689                    share
690                        .guests
691                        .values()
692                        .map(|(replica_id, _)| *replica_id)
693                        .collect::<HashSet<_>>(),
694                );
695            }
696        }
697
698        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
699            for project_id in visible_project_ids {
700                let project = self.projects.get(project_id).unwrap();
701                assert!(project.authorized_user_ids().contains(user_id));
702            }
703        }
704
705        for (channel_id, channel) in &self.channels {
706            for connection_id in &channel.connection_ids {
707                let connection = self.connections.get(connection_id).unwrap();
708                assert!(connection.channels.contains(channel_id));
709            }
710        }
711    }
712}
713
714impl Project {
715    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
716        self.worktrees
717            .values()
718            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
719    }
720
721    pub fn authorized_user_ids(&self) -> Vec<UserId> {
722        let mut ids = self
723            .worktrees
724            .values()
725            .flat_map(|worktree| worktree.authorized_user_ids.iter())
726            .copied()
727            .collect::<Vec<_>>();
728        ids.sort_unstable();
729        ids.dedup();
730        ids
731    }
732
733    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
734        if let Some(share) = &self.share {
735            share.guests.keys().copied().collect()
736        } else {
737            Vec::new()
738        }
739    }
740
741    pub fn connection_ids(&self) -> Vec<ConnectionId> {
742        if let Some(share) = &self.share {
743            share
744                .guests
745                .keys()
746                .copied()
747                .chain(Some(self.host_connection_id))
748                .collect()
749        } else {
750            vec![self.host_connection_id]
751        }
752    }
753
754    pub fn share(&self) -> tide::Result<&ProjectShare> {
755        Ok(self
756            .share
757            .as_ref()
758            .ok_or_else(|| anyhow!("worktree is not shared"))?)
759    }
760
761    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
762        Ok(self
763            .share
764            .as_mut()
765            .ok_or_else(|| anyhow!("worktree is not shared"))?)
766    }
767}
768
769impl Channel {
770    fn connection_ids(&self) -> Vec<ConnectionId> {
771        self.connection_ids.iter().copied().collect()
772    }
773}