store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub weak: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47}
 48
 49#[derive(Default)]
 50pub struct Channel {
 51    pub connection_ids: HashSet<ConnectionId>,
 52}
 53
 54pub type ReplicaId = u16;
 55
 56#[derive(Default)]
 57pub struct RemovedConnectionState {
 58    pub hosted_projects: HashMap<u64, Project>,
 59    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 60    pub contact_ids: HashSet<UserId>,
 61}
 62
 63pub struct JoinedProject<'a> {
 64    pub replica_id: ReplicaId,
 65    pub project: &'a Project,
 66}
 67
 68pub struct UnsharedProject {
 69    pub connection_ids: Vec<ConnectionId>,
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct LeftProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78impl Store {
 79    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 80        self.connections.insert(
 81            connection_id,
 82            ConnectionState {
 83                user_id,
 84                projects: Default::default(),
 85                channels: Default::default(),
 86            },
 87        );
 88        self.connections_by_user_id
 89            .entry(user_id)
 90            .or_default()
 91            .insert(connection_id);
 92    }
 93
 94    pub fn remove_connection(
 95        &mut self,
 96        connection_id: ConnectionId,
 97    ) -> tide::Result<RemovedConnectionState> {
 98        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 99            connection
100        } else {
101            return Err(anyhow!("no such connection"))?;
102        };
103
104        for channel_id in &connection.channels {
105            if let Some(channel) = self.channels.get_mut(&channel_id) {
106                channel.connection_ids.remove(&connection_id);
107            }
108        }
109
110        let user_connections = self
111            .connections_by_user_id
112            .get_mut(&connection.user_id)
113            .unwrap();
114        user_connections.remove(&connection_id);
115        if user_connections.is_empty() {
116            self.connections_by_user_id.remove(&connection.user_id);
117        }
118
119        let mut result = RemovedConnectionState::default();
120        for project_id in connection.projects.clone() {
121            if let Ok(project) = self.unregister_project(project_id, connection_id) {
122                result.contact_ids.extend(project.authorized_user_ids());
123                result.hosted_projects.insert(project_id, project);
124            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125                result
126                    .guest_project_ids
127                    .insert(project_id, project.connection_ids);
128                result.contact_ids.extend(project.authorized_user_ids);
129            }
130        }
131
132        #[cfg(test)]
133        self.check_invariants();
134
135        Ok(result)
136    }
137
138    #[cfg(test)]
139    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140        self.channels.get(&id)
141    }
142
143    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.insert(channel_id);
146            self.channels
147                .entry(channel_id)
148                .or_default()
149                .connection_ids
150                .insert(connection_id);
151        }
152    }
153
154    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155        if let Some(connection) = self.connections.get_mut(&connection_id) {
156            connection.channels.remove(&channel_id);
157            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158                entry.get_mut().connection_ids.remove(&connection_id);
159                if entry.get_mut().connection_ids.is_empty() {
160                    entry.remove();
161                }
162            }
163        }
164    }
165
166    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167        Ok(self
168            .connections
169            .get(&connection_id)
170            .ok_or_else(|| anyhow!("unknown connection"))?
171            .user_id)
172    }
173
174    pub fn connection_ids_for_user<'a>(
175        &'a self,
176        user_id: UserId,
177    ) -> impl 'a + Iterator<Item = ConnectionId> {
178        self.connections_by_user_id
179            .get(&user_id)
180            .into_iter()
181            .flatten()
182            .copied()
183    }
184
185    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186        let mut contacts = HashMap::default();
187        for project_id in self
188            .visible_projects_by_user_id
189            .get(&user_id)
190            .unwrap_or(&HashSet::default())
191        {
192            let project = &self.projects[project_id];
193
194            let mut guests = HashSet::default();
195            if let Ok(share) = project.share() {
196                for guest_connection_id in share.guests.keys() {
197                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198                        guests.insert(user_id.to_proto());
199                    }
200                }
201            }
202
203            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204                let mut worktree_root_names = project
205                    .worktrees
206                    .values()
207                    .filter(|worktree| !worktree.weak)
208                    .map(|worktree| worktree.root_name.clone())
209                    .collect::<Vec<_>>();
210                worktree_root_names.sort_unstable();
211                contacts
212                    .entry(host_user_id)
213                    .or_insert_with(|| proto::Contact {
214                        user_id: host_user_id.to_proto(),
215                        projects: Vec::new(),
216                    })
217                    .projects
218                    .push(proto::ProjectMetadata {
219                        id: *project_id,
220                        worktree_root_names,
221                        is_shared: project.share.is_some(),
222                        guests: guests.into_iter().collect(),
223                    });
224            }
225        }
226
227        contacts.into_values().collect()
228    }
229
230    pub fn register_project(
231        &mut self,
232        host_connection_id: ConnectionId,
233        host_user_id: UserId,
234    ) -> u64 {
235        let project_id = self.next_project_id;
236        self.projects.insert(
237            project_id,
238            Project {
239                host_connection_id,
240                host_user_id,
241                share: None,
242                worktrees: Default::default(),
243            },
244        );
245        self.next_project_id += 1;
246        project_id
247    }
248
249    pub fn register_worktree(
250        &mut self,
251        project_id: u64,
252        worktree_id: u64,
253        connection_id: ConnectionId,
254        worktree: Worktree,
255    ) -> tide::Result<()> {
256        let project = self
257            .projects
258            .get_mut(&project_id)
259            .ok_or_else(|| anyhow!("no such project"))?;
260        if project.host_connection_id == connection_id {
261            for authorized_user_id in &worktree.authorized_user_ids {
262                self.visible_projects_by_user_id
263                    .entry(*authorized_user_id)
264                    .or_default()
265                    .insert(project_id);
266            }
267            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268                connection.projects.insert(project_id);
269            }
270            project.worktrees.insert(worktree_id, worktree);
271            if let Ok(share) = project.share_mut() {
272                share.worktrees.insert(worktree_id, Default::default());
273            }
274
275            #[cfg(test)]
276            self.check_invariants();
277            Ok(())
278        } else {
279            Err(anyhow!("no such project"))?
280        }
281    }
282
283    pub fn unregister_project(
284        &mut self,
285        project_id: u64,
286        connection_id: ConnectionId,
287    ) -> tide::Result<Project> {
288        match self.projects.entry(project_id) {
289            hash_map::Entry::Occupied(e) => {
290                if e.get().host_connection_id == connection_id {
291                    for user_id in e.get().authorized_user_ids() {
292                        if let hash_map::Entry::Occupied(mut projects) =
293                            self.visible_projects_by_user_id.entry(user_id)
294                        {
295                            projects.get_mut().remove(&project_id);
296                        }
297                    }
298
299                    Ok(e.remove())
300                } else {
301                    Err(anyhow!("no such project"))?
302                }
303            }
304            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
305        }
306    }
307
308    pub fn unregister_worktree(
309        &mut self,
310        project_id: u64,
311        worktree_id: u64,
312        acting_connection_id: ConnectionId,
313    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
314        let project = self
315            .projects
316            .get_mut(&project_id)
317            .ok_or_else(|| anyhow!("no such project"))?;
318        if project.host_connection_id != acting_connection_id {
319            Err(anyhow!("not your worktree"))?;
320        }
321
322        let worktree = project
323            .worktrees
324            .remove(&worktree_id)
325            .ok_or_else(|| anyhow!("no such worktree"))?;
326
327        let mut guest_connection_ids = Vec::new();
328        if let Ok(share) = project.share_mut() {
329            guest_connection_ids.extend(share.guests.keys());
330            share.worktrees.remove(&worktree_id);
331        }
332
333        for authorized_user_id in &worktree.authorized_user_ids {
334            if let Some(visible_projects) =
335                self.visible_projects_by_user_id.get_mut(authorized_user_id)
336            {
337                if !project.has_authorized_user_id(*authorized_user_id) {
338                    visible_projects.remove(&project_id);
339                }
340            }
341        }
342
343        #[cfg(test)]
344        self.check_invariants();
345
346        Ok((worktree, guest_connection_ids))
347    }
348
349    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
350        if let Some(project) = self.projects.get_mut(&project_id) {
351            if project.host_connection_id == connection_id {
352                let mut share = ProjectShare::default();
353                for worktree_id in project.worktrees.keys() {
354                    share.worktrees.insert(*worktree_id, Default::default());
355                }
356                project.share = Some(share);
357                return true;
358            }
359        }
360        false
361    }
362
363    pub fn unshare_project(
364        &mut self,
365        project_id: u64,
366        acting_connection_id: ConnectionId,
367    ) -> tide::Result<UnsharedProject> {
368        let project = if let Some(project) = self.projects.get_mut(&project_id) {
369            project
370        } else {
371            return Err(anyhow!("no such project"))?;
372        };
373
374        if project.host_connection_id != acting_connection_id {
375            return Err(anyhow!("not your project"))?;
376        }
377
378        let connection_ids = project.connection_ids();
379        let authorized_user_ids = project.authorized_user_ids();
380        if let Some(share) = project.share.take() {
381            for connection_id in share.guests.into_keys() {
382                if let Some(connection) = self.connections.get_mut(&connection_id) {
383                    connection.projects.remove(&project_id);
384                }
385            }
386
387            #[cfg(test)]
388            self.check_invariants();
389
390            Ok(UnsharedProject {
391                connection_ids,
392                authorized_user_ids,
393            })
394        } else {
395            Err(anyhow!("project is not shared"))?
396        }
397    }
398
399    pub fn update_diagnostic_summary(
400        &mut self,
401        project_id: u64,
402        worktree_id: u64,
403        connection_id: ConnectionId,
404        summary: proto::DiagnosticSummary,
405    ) -> tide::Result<Vec<ConnectionId>> {
406        let project = self
407            .projects
408            .get_mut(&project_id)
409            .ok_or_else(|| anyhow!("no such project"))?;
410        if project.host_connection_id == connection_id {
411            let worktree = project
412                .share_mut()?
413                .worktrees
414                .get_mut(&worktree_id)
415                .ok_or_else(|| anyhow!("no such worktree"))?;
416            worktree
417                .diagnostic_summaries
418                .insert(summary.path.clone().into(), summary);
419            return Ok(project.connection_ids());
420        }
421
422        Err(anyhow!("no such worktree"))?
423    }
424
425    pub fn join_project(
426        &mut self,
427        connection_id: ConnectionId,
428        user_id: UserId,
429        project_id: u64,
430    ) -> tide::Result<JoinedProject> {
431        let connection = self
432            .connections
433            .get_mut(&connection_id)
434            .ok_or_else(|| anyhow!("no such connection"))?;
435        let project = self
436            .projects
437            .get_mut(&project_id)
438            .and_then(|project| {
439                if project.has_authorized_user_id(user_id) {
440                    Some(project)
441                } else {
442                    None
443                }
444            })
445            .ok_or_else(|| anyhow!("no such project"))?;
446
447        let share = project.share_mut()?;
448        connection.projects.insert(project_id);
449
450        let mut replica_id = 1;
451        while share.active_replica_ids.contains(&replica_id) {
452            replica_id += 1;
453        }
454        share.active_replica_ids.insert(replica_id);
455        share.guests.insert(connection_id, (replica_id, user_id));
456
457        #[cfg(test)]
458        self.check_invariants();
459
460        Ok(JoinedProject {
461            replica_id,
462            project: &self.projects[&project_id],
463        })
464    }
465
466    pub fn leave_project(
467        &mut self,
468        connection_id: ConnectionId,
469        project_id: u64,
470    ) -> tide::Result<LeftProject> {
471        let project = self
472            .projects
473            .get_mut(&project_id)
474            .ok_or_else(|| anyhow!("no such project"))?;
475        let share = project
476            .share
477            .as_mut()
478            .ok_or_else(|| anyhow!("project is not shared"))?;
479        let (replica_id, _) = share
480            .guests
481            .remove(&connection_id)
482            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
483        share.active_replica_ids.remove(&replica_id);
484
485        if let Some(connection) = self.connections.get_mut(&connection_id) {
486            connection.projects.remove(&project_id);
487        }
488
489        let connection_ids = project.connection_ids();
490        let authorized_user_ids = project.authorized_user_ids();
491
492        #[cfg(test)]
493        self.check_invariants();
494
495        Ok(LeftProject {
496            connection_ids,
497            authorized_user_ids,
498        })
499    }
500
501    pub fn update_worktree(
502        &mut self,
503        connection_id: ConnectionId,
504        project_id: u64,
505        worktree_id: u64,
506        removed_entries: &[u64],
507        updated_entries: &[proto::Entry],
508    ) -> tide::Result<Vec<ConnectionId>> {
509        let project = self.write_project(project_id, connection_id)?;
510        let worktree = project
511            .share_mut()?
512            .worktrees
513            .get_mut(&worktree_id)
514            .ok_or_else(|| anyhow!("no such worktree"))?;
515        for entry_id in removed_entries {
516            worktree.entries.remove(&entry_id);
517        }
518        for entry in updated_entries {
519            worktree.entries.insert(entry.id, entry.clone());
520        }
521        Ok(project.connection_ids())
522    }
523
524    pub fn project_connection_ids(
525        &self,
526        project_id: u64,
527        acting_connection_id: ConnectionId,
528    ) -> tide::Result<Vec<ConnectionId>> {
529        Ok(self
530            .read_project(project_id, acting_connection_id)?
531            .connection_ids())
532    }
533
534    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
535        Ok(self
536            .channels
537            .get(&channel_id)
538            .ok_or_else(|| anyhow!("no such channel"))?
539            .connection_ids())
540    }
541
542    #[cfg(test)]
543    pub fn project(&self, project_id: u64) -> Option<&Project> {
544        self.projects.get(&project_id)
545    }
546
547    pub fn read_project(
548        &self,
549        project_id: u64,
550        connection_id: ConnectionId,
551    ) -> tide::Result<&Project> {
552        let project = self
553            .projects
554            .get(&project_id)
555            .ok_or_else(|| anyhow!("no such project"))?;
556        if project.host_connection_id == connection_id
557            || project
558                .share
559                .as_ref()
560                .ok_or_else(|| anyhow!("project is not shared"))?
561                .guests
562                .contains_key(&connection_id)
563        {
564            Ok(project)
565        } else {
566            Err(anyhow!("no such project"))?
567        }
568    }
569
570    fn write_project(
571        &mut self,
572        project_id: u64,
573        connection_id: ConnectionId,
574    ) -> tide::Result<&mut Project> {
575        let project = self
576            .projects
577            .get_mut(&project_id)
578            .ok_or_else(|| anyhow!("no such project"))?;
579        if project.host_connection_id == connection_id
580            || project
581                .share
582                .as_ref()
583                .ok_or_else(|| anyhow!("project is not shared"))?
584                .guests
585                .contains_key(&connection_id)
586        {
587            Ok(project)
588        } else {
589            Err(anyhow!("no such project"))?
590        }
591    }
592
593    #[cfg(test)]
594    fn check_invariants(&self) {
595        for (connection_id, connection) in &self.connections {
596            for project_id in &connection.projects {
597                let project = &self.projects.get(&project_id).unwrap();
598                if project.host_connection_id != *connection_id {
599                    assert!(project
600                        .share
601                        .as_ref()
602                        .unwrap()
603                        .guests
604                        .contains_key(connection_id));
605                }
606            }
607            for channel_id in &connection.channels {
608                let channel = self.channels.get(channel_id).unwrap();
609                assert!(channel.connection_ids.contains(connection_id));
610            }
611            assert!(self
612                .connections_by_user_id
613                .get(&connection.user_id)
614                .unwrap()
615                .contains(connection_id));
616        }
617
618        for (user_id, connection_ids) in &self.connections_by_user_id {
619            for connection_id in connection_ids {
620                assert_eq!(
621                    self.connections.get(connection_id).unwrap().user_id,
622                    *user_id
623                );
624            }
625        }
626
627        for (project_id, project) in &self.projects {
628            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
629            assert!(host_connection.projects.contains(project_id));
630
631            for authorized_user_ids in project.authorized_user_ids() {
632                let visible_project_ids = self
633                    .visible_projects_by_user_id
634                    .get(&authorized_user_ids)
635                    .unwrap();
636                assert!(visible_project_ids.contains(project_id));
637            }
638
639            if let Some(share) = &project.share {
640                for guest_connection_id in share.guests.keys() {
641                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
642                    assert!(guest_connection.projects.contains(project_id));
643                }
644                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
645                assert_eq!(
646                    share.active_replica_ids,
647                    share
648                        .guests
649                        .values()
650                        .map(|(replica_id, _)| *replica_id)
651                        .collect::<HashSet<_>>(),
652                );
653            }
654        }
655
656        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
657            for project_id in visible_project_ids {
658                let project = self.projects.get(project_id).unwrap();
659                assert!(project.authorized_user_ids().contains(user_id));
660            }
661        }
662
663        for (channel_id, channel) in &self.channels {
664            for connection_id in &channel.connection_ids {
665                let connection = self.connections.get(connection_id).unwrap();
666                assert!(connection.channels.contains(channel_id));
667            }
668        }
669    }
670}
671
672impl Project {
673    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
674        self.worktrees
675            .values()
676            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
677    }
678
679    pub fn authorized_user_ids(&self) -> Vec<UserId> {
680        let mut ids = self
681            .worktrees
682            .values()
683            .flat_map(|worktree| worktree.authorized_user_ids.iter())
684            .copied()
685            .collect::<Vec<_>>();
686        ids.sort_unstable();
687        ids.dedup();
688        ids
689    }
690
691    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
692        if let Some(share) = &self.share {
693            share.guests.keys().copied().collect()
694        } else {
695            Vec::new()
696        }
697    }
698
699    pub fn connection_ids(&self) -> Vec<ConnectionId> {
700        if let Some(share) = &self.share {
701            share
702                .guests
703                .keys()
704                .copied()
705                .chain(Some(self.host_connection_id))
706                .collect()
707        } else {
708            vec![self.host_connection_id]
709        }
710    }
711
712    pub fn share(&self) -> tide::Result<&ProjectShare> {
713        Ok(self
714            .share
715            .as_ref()
716            .ok_or_else(|| anyhow!("worktree is not shared"))?)
717    }
718
719    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
720        Ok(self
721            .share
722            .as_mut()
723            .ok_or_else(|| anyhow!("worktree is not shared"))?)
724    }
725}
726
727impl Channel {
728    fn connection_ids(&self) -> Vec<ConnectionId> {
729        self.connection_ids.iter().copied().collect()
730    }
731}