store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub visible: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47}
 48
 49#[derive(Default)]
 50pub struct Channel {
 51    pub connection_ids: HashSet<ConnectionId>,
 52}
 53
 54pub type ReplicaId = u16;
 55
 56#[derive(Default)]
 57pub struct RemovedConnectionState {
 58    pub hosted_projects: HashMap<u64, Project>,
 59    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 60    pub contact_ids: HashSet<UserId>,
 61}
 62
 63pub struct JoinedProject<'a> {
 64    pub replica_id: ReplicaId,
 65    pub project: &'a Project,
 66}
 67
 68pub struct UnsharedProject {
 69    pub connection_ids: Vec<ConnectionId>,
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct LeftProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78impl Store {
 79    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 80        self.connections.insert(
 81            connection_id,
 82            ConnectionState {
 83                user_id,
 84                projects: Default::default(),
 85                channels: Default::default(),
 86            },
 87        );
 88        self.connections_by_user_id
 89            .entry(user_id)
 90            .or_default()
 91            .insert(connection_id);
 92    }
 93
 94    pub fn remove_connection(
 95        &mut self,
 96        connection_id: ConnectionId,
 97    ) -> tide::Result<RemovedConnectionState> {
 98        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 99            connection
100        } else {
101            return Err(anyhow!("no such connection"))?;
102        };
103
104        for channel_id in &connection.channels {
105            if let Some(channel) = self.channels.get_mut(&channel_id) {
106                channel.connection_ids.remove(&connection_id);
107            }
108        }
109
110        let user_connections = self
111            .connections_by_user_id
112            .get_mut(&connection.user_id)
113            .unwrap();
114        user_connections.remove(&connection_id);
115        if user_connections.is_empty() {
116            self.connections_by_user_id.remove(&connection.user_id);
117        }
118
119        let mut result = RemovedConnectionState::default();
120        for project_id in connection.projects.clone() {
121            if let Ok(project) = self.unregister_project(project_id, connection_id) {
122                result.contact_ids.extend(project.authorized_user_ids());
123                result.hosted_projects.insert(project_id, project);
124            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125                result
126                    .guest_project_ids
127                    .insert(project_id, project.connection_ids);
128                result.contact_ids.extend(project.authorized_user_ids);
129            }
130        }
131
132        #[cfg(test)]
133        self.check_invariants();
134
135        Ok(result)
136    }
137
138    #[cfg(test)]
139    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140        self.channels.get(&id)
141    }
142
143    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.insert(channel_id);
146            self.channels
147                .entry(channel_id)
148                .or_default()
149                .connection_ids
150                .insert(connection_id);
151        }
152    }
153
154    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155        if let Some(connection) = self.connections.get_mut(&connection_id) {
156            connection.channels.remove(&channel_id);
157            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158                entry.get_mut().connection_ids.remove(&connection_id);
159                if entry.get_mut().connection_ids.is_empty() {
160                    entry.remove();
161                }
162            }
163        }
164    }
165
166    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167        Ok(self
168            .connections
169            .get(&connection_id)
170            .ok_or_else(|| anyhow!("unknown connection"))?
171            .user_id)
172    }
173
174    pub fn connection_ids_for_user<'a>(
175        &'a self,
176        user_id: UserId,
177    ) -> impl 'a + Iterator<Item = ConnectionId> {
178        self.connections_by_user_id
179            .get(&user_id)
180            .into_iter()
181            .flatten()
182            .copied()
183    }
184
185    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186        let mut contacts = HashMap::default();
187        for project_id in self
188            .visible_projects_by_user_id
189            .get(&user_id)
190            .unwrap_or(&HashSet::default())
191        {
192            let project = &self.projects[project_id];
193
194            let mut guests = HashSet::default();
195            if let Ok(share) = project.share() {
196                for guest_connection_id in share.guests.keys() {
197                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198                        guests.insert(user_id.to_proto());
199                    }
200                }
201            }
202
203            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204                let mut worktree_root_names = project
205                    .worktrees
206                    .values()
207                    .filter(|worktree| worktree.visible)
208                    .map(|worktree| worktree.root_name.clone())
209                    .collect::<Vec<_>>();
210                worktree_root_names.sort_unstable();
211                contacts
212                    .entry(host_user_id)
213                    .or_insert_with(|| proto::Contact {
214                        user_id: host_user_id.to_proto(),
215                        projects: Vec::new(),
216                    })
217                    .projects
218                    .push(proto::ProjectMetadata {
219                        id: *project_id,
220                        worktree_root_names,
221                        is_shared: project.share.is_some(),
222                        guests: guests.into_iter().collect(),
223                    });
224            }
225        }
226
227        contacts.into_values().collect()
228    }
229
230    pub fn register_project(
231        &mut self,
232        host_connection_id: ConnectionId,
233        host_user_id: UserId,
234    ) -> u64 {
235        let project_id = self.next_project_id;
236        self.projects.insert(
237            project_id,
238            Project {
239                host_connection_id,
240                host_user_id,
241                share: None,
242                worktrees: Default::default(),
243            },
244        );
245        self.next_project_id += 1;
246        project_id
247    }
248
249    pub fn register_worktree(
250        &mut self,
251        project_id: u64,
252        worktree_id: u64,
253        connection_id: ConnectionId,
254        worktree: Worktree,
255    ) -> tide::Result<()> {
256        let project = self
257            .projects
258            .get_mut(&project_id)
259            .ok_or_else(|| anyhow!("no such project"))?;
260        if project.host_connection_id == connection_id {
261            for authorized_user_id in &worktree.authorized_user_ids {
262                self.visible_projects_by_user_id
263                    .entry(*authorized_user_id)
264                    .or_default()
265                    .insert(project_id);
266            }
267            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268                connection.projects.insert(project_id);
269            }
270            project.worktrees.insert(worktree_id, worktree);
271            if let Ok(share) = project.share_mut() {
272                share.worktrees.insert(worktree_id, Default::default());
273            }
274
275            #[cfg(test)]
276            self.check_invariants();
277            Ok(())
278        } else {
279            Err(anyhow!("no such project"))?
280        }
281    }
282
283    pub fn unregister_project(
284        &mut self,
285        project_id: u64,
286        connection_id: ConnectionId,
287    ) -> tide::Result<Project> {
288        match self.projects.entry(project_id) {
289            hash_map::Entry::Occupied(e) => {
290                if e.get().host_connection_id == connection_id {
291                    for user_id in e.get().authorized_user_ids() {
292                        if let hash_map::Entry::Occupied(mut projects) =
293                            self.visible_projects_by_user_id.entry(user_id)
294                        {
295                            projects.get_mut().remove(&project_id);
296                        }
297                    }
298
299                    let project = e.remove();
300
301                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
302                        host_connection.projects.remove(&project_id);
303                    }
304
305                    if let Some(share) = &project.share {
306                        for guest_connection in share.guests.keys() {
307                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
308                                connection.projects.remove(&project_id);
309                            }
310                        }
311                    }
312
313                    #[cfg(test)]
314                    self.check_invariants();
315                    Ok(project)
316                } else {
317                    Err(anyhow!("no such project"))?
318                }
319            }
320            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
321        }
322    }
323
324    pub fn unregister_worktree(
325        &mut self,
326        project_id: u64,
327        worktree_id: u64,
328        acting_connection_id: ConnectionId,
329    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
330        let project = self
331            .projects
332            .get_mut(&project_id)
333            .ok_or_else(|| anyhow!("no such project"))?;
334        if project.host_connection_id != acting_connection_id {
335            Err(anyhow!("not your worktree"))?;
336        }
337
338        let worktree = project
339            .worktrees
340            .remove(&worktree_id)
341            .ok_or_else(|| anyhow!("no such worktree"))?;
342
343        let mut guest_connection_ids = Vec::new();
344        if let Ok(share) = project.share_mut() {
345            guest_connection_ids.extend(share.guests.keys());
346            share.worktrees.remove(&worktree_id);
347        }
348
349        for authorized_user_id in &worktree.authorized_user_ids {
350            if let Some(visible_projects) =
351                self.visible_projects_by_user_id.get_mut(authorized_user_id)
352            {
353                if !project.has_authorized_user_id(*authorized_user_id) {
354                    visible_projects.remove(&project_id);
355                }
356            }
357        }
358
359        #[cfg(test)]
360        self.check_invariants();
361
362        Ok((worktree, guest_connection_ids))
363    }
364
365    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
366        if let Some(project) = self.projects.get_mut(&project_id) {
367            if project.host_connection_id == connection_id {
368                let mut share = ProjectShare::default();
369                for worktree_id in project.worktrees.keys() {
370                    share.worktrees.insert(*worktree_id, Default::default());
371                }
372                project.share = Some(share);
373                return true;
374            }
375        }
376        false
377    }
378
379    pub fn unshare_project(
380        &mut self,
381        project_id: u64,
382        acting_connection_id: ConnectionId,
383    ) -> tide::Result<UnsharedProject> {
384        let project = if let Some(project) = self.projects.get_mut(&project_id) {
385            project
386        } else {
387            return Err(anyhow!("no such project"))?;
388        };
389
390        if project.host_connection_id != acting_connection_id {
391            return Err(anyhow!("not your project"))?;
392        }
393
394        let connection_ids = project.connection_ids();
395        let authorized_user_ids = project.authorized_user_ids();
396        if let Some(share) = project.share.take() {
397            for connection_id in share.guests.into_keys() {
398                if let Some(connection) = self.connections.get_mut(&connection_id) {
399                    connection.projects.remove(&project_id);
400                }
401            }
402
403            #[cfg(test)]
404            self.check_invariants();
405
406            Ok(UnsharedProject {
407                connection_ids,
408                authorized_user_ids,
409            })
410        } else {
411            Err(anyhow!("project is not shared"))?
412        }
413    }
414
415    pub fn update_diagnostic_summary(
416        &mut self,
417        project_id: u64,
418        worktree_id: u64,
419        connection_id: ConnectionId,
420        summary: proto::DiagnosticSummary,
421    ) -> tide::Result<Vec<ConnectionId>> {
422        let project = self
423            .projects
424            .get_mut(&project_id)
425            .ok_or_else(|| anyhow!("no such project"))?;
426        if project.host_connection_id == connection_id {
427            let worktree = project
428                .share_mut()?
429                .worktrees
430                .get_mut(&worktree_id)
431                .ok_or_else(|| anyhow!("no such worktree"))?;
432            worktree
433                .diagnostic_summaries
434                .insert(summary.path.clone().into(), summary);
435            return Ok(project.connection_ids());
436        }
437
438        Err(anyhow!("no such worktree"))?
439    }
440
441    pub fn join_project(
442        &mut self,
443        connection_id: ConnectionId,
444        user_id: UserId,
445        project_id: u64,
446    ) -> tide::Result<JoinedProject> {
447        let connection = self
448            .connections
449            .get_mut(&connection_id)
450            .ok_or_else(|| anyhow!("no such connection"))?;
451        let project = self
452            .projects
453            .get_mut(&project_id)
454            .and_then(|project| {
455                if project.has_authorized_user_id(user_id) {
456                    Some(project)
457                } else {
458                    None
459                }
460            })
461            .ok_or_else(|| anyhow!("no such project"))?;
462
463        let share = project.share_mut()?;
464        connection.projects.insert(project_id);
465
466        let mut replica_id = 1;
467        while share.active_replica_ids.contains(&replica_id) {
468            replica_id += 1;
469        }
470        share.active_replica_ids.insert(replica_id);
471        share.guests.insert(connection_id, (replica_id, user_id));
472
473        #[cfg(test)]
474        self.check_invariants();
475
476        Ok(JoinedProject {
477            replica_id,
478            project: &self.projects[&project_id],
479        })
480    }
481
482    pub fn leave_project(
483        &mut self,
484        connection_id: ConnectionId,
485        project_id: u64,
486    ) -> tide::Result<LeftProject> {
487        let project = self
488            .projects
489            .get_mut(&project_id)
490            .ok_or_else(|| anyhow!("no such project"))?;
491        let share = project
492            .share
493            .as_mut()
494            .ok_or_else(|| anyhow!("project is not shared"))?;
495        let (replica_id, _) = share
496            .guests
497            .remove(&connection_id)
498            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
499        share.active_replica_ids.remove(&replica_id);
500
501        if let Some(connection) = self.connections.get_mut(&connection_id) {
502            connection.projects.remove(&project_id);
503        }
504
505        let connection_ids = project.connection_ids();
506        let authorized_user_ids = project.authorized_user_ids();
507
508        #[cfg(test)]
509        self.check_invariants();
510
511        Ok(LeftProject {
512            connection_ids,
513            authorized_user_ids,
514        })
515    }
516
517    pub fn update_worktree(
518        &mut self,
519        connection_id: ConnectionId,
520        project_id: u64,
521        worktree_id: u64,
522        removed_entries: &[u64],
523        updated_entries: &[proto::Entry],
524    ) -> tide::Result<Vec<ConnectionId>> {
525        let project = self.write_project(project_id, connection_id)?;
526        let worktree = project
527            .share_mut()?
528            .worktrees
529            .get_mut(&worktree_id)
530            .ok_or_else(|| anyhow!("no such worktree"))?;
531        for entry_id in removed_entries {
532            worktree.entries.remove(&entry_id);
533        }
534        for entry in updated_entries {
535            worktree.entries.insert(entry.id, entry.clone());
536        }
537        let connection_ids = project.connection_ids();
538
539        #[cfg(test)]
540        self.check_invariants();
541
542        Ok(connection_ids)
543    }
544
545    pub fn project_connection_ids(
546        &self,
547        project_id: u64,
548        acting_connection_id: ConnectionId,
549    ) -> tide::Result<Vec<ConnectionId>> {
550        Ok(self
551            .read_project(project_id, acting_connection_id)?
552            .connection_ids())
553    }
554
555    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
556        Ok(self
557            .channels
558            .get(&channel_id)
559            .ok_or_else(|| anyhow!("no such channel"))?
560            .connection_ids())
561    }
562
563    #[cfg(test)]
564    pub fn project(&self, project_id: u64) -> Option<&Project> {
565        self.projects.get(&project_id)
566    }
567
568    pub fn read_project(
569        &self,
570        project_id: u64,
571        connection_id: ConnectionId,
572    ) -> tide::Result<&Project> {
573        let project = self
574            .projects
575            .get(&project_id)
576            .ok_or_else(|| anyhow!("no such project"))?;
577        if project.host_connection_id == connection_id
578            || project
579                .share
580                .as_ref()
581                .ok_or_else(|| anyhow!("project is not shared"))?
582                .guests
583                .contains_key(&connection_id)
584        {
585            Ok(project)
586        } else {
587            Err(anyhow!("no such project"))?
588        }
589    }
590
591    fn write_project(
592        &mut self,
593        project_id: u64,
594        connection_id: ConnectionId,
595    ) -> tide::Result<&mut Project> {
596        let project = self
597            .projects
598            .get_mut(&project_id)
599            .ok_or_else(|| anyhow!("no such project"))?;
600        if project.host_connection_id == connection_id
601            || project
602                .share
603                .as_ref()
604                .ok_or_else(|| anyhow!("project is not shared"))?
605                .guests
606                .contains_key(&connection_id)
607        {
608            Ok(project)
609        } else {
610            Err(anyhow!("no such project"))?
611        }
612    }
613
614    #[cfg(test)]
615    fn check_invariants(&self) {
616        for (connection_id, connection) in &self.connections {
617            for project_id in &connection.projects {
618                let project = &self.projects.get(&project_id).unwrap();
619                if project.host_connection_id != *connection_id {
620                    assert!(project
621                        .share
622                        .as_ref()
623                        .unwrap()
624                        .guests
625                        .contains_key(connection_id));
626                }
627
628                if let Some(share) = project.share.as_ref() {
629                    for (worktree_id, worktree) in share.worktrees.iter() {
630                        let mut paths = HashMap::default();
631                        for entry in worktree.entries.values() {
632                            let prev_entry = paths.insert(&entry.path, entry);
633                            assert_eq!(
634                                prev_entry,
635                                None,
636                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
637                                worktree_id,
638                                prev_entry.unwrap(),
639                                entry
640                            );
641                        }
642                    }
643                }
644            }
645            for channel_id in &connection.channels {
646                let channel = self.channels.get(channel_id).unwrap();
647                assert!(channel.connection_ids.contains(connection_id));
648            }
649            assert!(self
650                .connections_by_user_id
651                .get(&connection.user_id)
652                .unwrap()
653                .contains(connection_id));
654        }
655
656        for (user_id, connection_ids) in &self.connections_by_user_id {
657            for connection_id in connection_ids {
658                assert_eq!(
659                    self.connections.get(connection_id).unwrap().user_id,
660                    *user_id
661                );
662            }
663        }
664
665        for (project_id, project) in &self.projects {
666            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
667            assert!(host_connection.projects.contains(project_id));
668
669            for authorized_user_ids in project.authorized_user_ids() {
670                let visible_project_ids = self
671                    .visible_projects_by_user_id
672                    .get(&authorized_user_ids)
673                    .unwrap();
674                assert!(visible_project_ids.contains(project_id));
675            }
676
677            if let Some(share) = &project.share {
678                for guest_connection_id in share.guests.keys() {
679                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
680                    assert!(guest_connection.projects.contains(project_id));
681                }
682                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
683                assert_eq!(
684                    share.active_replica_ids,
685                    share
686                        .guests
687                        .values()
688                        .map(|(replica_id, _)| *replica_id)
689                        .collect::<HashSet<_>>(),
690                );
691            }
692        }
693
694        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
695            for project_id in visible_project_ids {
696                let project = self.projects.get(project_id).unwrap();
697                assert!(project.authorized_user_ids().contains(user_id));
698            }
699        }
700
701        for (channel_id, channel) in &self.channels {
702            for connection_id in &channel.connection_ids {
703                let connection = self.connections.get(connection_id).unwrap();
704                assert!(connection.channels.contains(channel_id));
705            }
706        }
707    }
708}
709
710impl Project {
711    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
712        self.worktrees
713            .values()
714            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
715    }
716
717    pub fn authorized_user_ids(&self) -> Vec<UserId> {
718        let mut ids = self
719            .worktrees
720            .values()
721            .flat_map(|worktree| worktree.authorized_user_ids.iter())
722            .copied()
723            .collect::<Vec<_>>();
724        ids.sort_unstable();
725        ids.dedup();
726        ids
727    }
728
729    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
730        if let Some(share) = &self.share {
731            share.guests.keys().copied().collect()
732        } else {
733            Vec::new()
734        }
735    }
736
737    pub fn connection_ids(&self) -> Vec<ConnectionId> {
738        if let Some(share) = &self.share {
739            share
740                .guests
741                .keys()
742                .copied()
743                .chain(Some(self.host_connection_id))
744                .collect()
745        } else {
746            vec![self.host_connection_id]
747        }
748    }
749
750    pub fn share(&self) -> tide::Result<&ProjectShare> {
751        Ok(self
752            .share
753            .as_ref()
754            .ok_or_else(|| anyhow!("worktree is not shared"))?)
755    }
756
757    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
758        Ok(self
759            .share
760            .as_mut()
761            .ok_or_else(|| anyhow!("worktree is not shared"))?)
762    }
763}
764
765impl Channel {
766    fn connection_ids(&self) -> Vec<ConnectionId> {
767        self.connection_ids.iter().copied().collect()
768    }
769}