store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34    pub weak: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41}
 42
 43pub struct WorktreeShare {
 44    pub entries: HashMap<u64, proto::Entry>,
 45    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 46}
 47
 48#[derive(Default)]
 49pub struct Channel {
 50    pub connection_ids: HashSet<ConnectionId>,
 51}
 52
 53pub type ReplicaId = u16;
 54
 55#[derive(Default)]
 56pub struct RemovedConnectionState {
 57    pub hosted_projects: HashMap<u64, Project>,
 58    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 59    pub contact_ids: HashSet<UserId>,
 60}
 61
 62pub struct JoinedProject<'a> {
 63    pub replica_id: ReplicaId,
 64    pub project: &'a Project,
 65}
 66
 67pub struct UnsharedProject {
 68    pub connection_ids: Vec<ConnectionId>,
 69    pub authorized_user_ids: Vec<UserId>,
 70}
 71
 72pub struct LeftProject {
 73    pub connection_ids: Vec<ConnectionId>,
 74    pub authorized_user_ids: Vec<UserId>,
 75}
 76
 77pub struct SharedWorktree {
 78    pub authorized_user_ids: Vec<UserId>,
 79    pub connection_ids: Vec<ConnectionId>,
 80}
 81
 82impl Store {
 83    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 84        self.connections.insert(
 85            connection_id,
 86            ConnectionState {
 87                user_id,
 88                projects: Default::default(),
 89                channels: Default::default(),
 90            },
 91        );
 92        self.connections_by_user_id
 93            .entry(user_id)
 94            .or_default()
 95            .insert(connection_id);
 96    }
 97
 98    pub fn remove_connection(
 99        &mut self,
100        connection_id: ConnectionId,
101    ) -> tide::Result<RemovedConnectionState> {
102        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
103            connection
104        } else {
105            return Err(anyhow!("no such connection"))?;
106        };
107
108        for channel_id in &connection.channels {
109            if let Some(channel) = self.channels.get_mut(&channel_id) {
110                channel.connection_ids.remove(&connection_id);
111            }
112        }
113
114        let user_connections = self
115            .connections_by_user_id
116            .get_mut(&connection.user_id)
117            .unwrap();
118        user_connections.remove(&connection_id);
119        if user_connections.is_empty() {
120            self.connections_by_user_id.remove(&connection.user_id);
121        }
122
123        let mut result = RemovedConnectionState::default();
124        for project_id in connection.projects.clone() {
125            if let Ok(project) = self.unregister_project(project_id, connection_id) {
126                result.contact_ids.extend(project.authorized_user_ids());
127                result.hosted_projects.insert(project_id, project);
128            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
129                result
130                    .guest_project_ids
131                    .insert(project_id, project.connection_ids);
132                result.contact_ids.extend(project.authorized_user_ids);
133            }
134        }
135
136        #[cfg(test)]
137        self.check_invariants();
138
139        Ok(result)
140    }
141
142    #[cfg(test)]
143    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
144        self.channels.get(&id)
145    }
146
147    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
148        if let Some(connection) = self.connections.get_mut(&connection_id) {
149            connection.channels.insert(channel_id);
150            self.channels
151                .entry(channel_id)
152                .or_default()
153                .connection_ids
154                .insert(connection_id);
155        }
156    }
157
158    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
159        if let Some(connection) = self.connections.get_mut(&connection_id) {
160            connection.channels.remove(&channel_id);
161            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
162                entry.get_mut().connection_ids.remove(&connection_id);
163                if entry.get_mut().connection_ids.is_empty() {
164                    entry.remove();
165                }
166            }
167        }
168    }
169
170    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
171        Ok(self
172            .connections
173            .get(&connection_id)
174            .ok_or_else(|| anyhow!("unknown connection"))?
175            .user_id)
176    }
177
178    pub fn connection_ids_for_user<'a>(
179        &'a self,
180        user_id: UserId,
181    ) -> impl 'a + Iterator<Item = ConnectionId> {
182        self.connections_by_user_id
183            .get(&user_id)
184            .into_iter()
185            .flatten()
186            .copied()
187    }
188
189    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
190        let mut contacts = HashMap::default();
191        for project_id in self
192            .visible_projects_by_user_id
193            .get(&user_id)
194            .unwrap_or(&HashSet::default())
195        {
196            let project = &self.projects[project_id];
197
198            let mut guests = HashSet::default();
199            if let Ok(share) = project.share() {
200                for guest_connection_id in share.guests.keys() {
201                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
202                        guests.insert(user_id.to_proto());
203                    }
204                }
205            }
206
207            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
208                let mut worktree_root_names = project
209                    .worktrees
210                    .values()
211                    .filter(|worktree| !worktree.weak)
212                    .map(|worktree| worktree.root_name.clone())
213                    .collect::<Vec<_>>();
214                worktree_root_names.sort_unstable();
215                contacts
216                    .entry(host_user_id)
217                    .or_insert_with(|| proto::Contact {
218                        user_id: host_user_id.to_proto(),
219                        projects: Vec::new(),
220                    })
221                    .projects
222                    .push(proto::ProjectMetadata {
223                        id: *project_id,
224                        worktree_root_names,
225                        is_shared: project.share.is_some(),
226                        guests: guests.into_iter().collect(),
227                    });
228            }
229        }
230
231        contacts.into_values().collect()
232    }
233
234    pub fn register_project(
235        &mut self,
236        host_connection_id: ConnectionId,
237        host_user_id: UserId,
238    ) -> u64 {
239        let project_id = self.next_project_id;
240        self.projects.insert(
241            project_id,
242            Project {
243                host_connection_id,
244                host_user_id,
245                share: None,
246                worktrees: Default::default(),
247            },
248        );
249        self.next_project_id += 1;
250        project_id
251    }
252
253    pub fn register_worktree(
254        &mut self,
255        project_id: u64,
256        worktree_id: u64,
257        connection_id: ConnectionId,
258        worktree: Worktree,
259    ) -> tide::Result<()> {
260        let project = self
261            .projects
262            .get_mut(&project_id)
263            .ok_or_else(|| anyhow!("no such project"))?;
264        if project.host_connection_id == connection_id {
265            for authorized_user_id in &worktree.authorized_user_ids {
266                self.visible_projects_by_user_id
267                    .entry(*authorized_user_id)
268                    .or_default()
269                    .insert(project_id);
270            }
271            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
272                connection.projects.insert(project_id);
273            }
274            project.worktrees.insert(worktree_id, worktree);
275
276            #[cfg(test)]
277            self.check_invariants();
278            Ok(())
279        } else {
280            Err(anyhow!("no such project"))?
281        }
282    }
283
284    pub fn unregister_project(
285        &mut self,
286        project_id: u64,
287        connection_id: ConnectionId,
288    ) -> tide::Result<Project> {
289        match self.projects.entry(project_id) {
290            hash_map::Entry::Occupied(e) => {
291                if e.get().host_connection_id == connection_id {
292                    for user_id in e.get().authorized_user_ids() {
293                        if let hash_map::Entry::Occupied(mut projects) =
294                            self.visible_projects_by_user_id.entry(user_id)
295                        {
296                            projects.get_mut().remove(&project_id);
297                        }
298                    }
299
300                    Ok(e.remove())
301                } else {
302                    Err(anyhow!("no such project"))?
303                }
304            }
305            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
306        }
307    }
308
309    pub fn unregister_worktree(
310        &mut self,
311        project_id: u64,
312        worktree_id: u64,
313        acting_connection_id: ConnectionId,
314    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
315        let project = self
316            .projects
317            .get_mut(&project_id)
318            .ok_or_else(|| anyhow!("no such project"))?;
319        if project.host_connection_id != acting_connection_id {
320            Err(anyhow!("not your worktree"))?;
321        }
322
323        let worktree = project
324            .worktrees
325            .remove(&worktree_id)
326            .ok_or_else(|| anyhow!("no such worktree"))?;
327
328        let mut guest_connection_ids = Vec::new();
329        if let Some(share) = &project.share {
330            guest_connection_ids.extend(share.guests.keys());
331        }
332
333        for authorized_user_id in &worktree.authorized_user_ids {
334            if let Some(visible_projects) =
335                self.visible_projects_by_user_id.get_mut(authorized_user_id)
336            {
337                if !project.has_authorized_user_id(*authorized_user_id) {
338                    visible_projects.remove(&project_id);
339                }
340            }
341        }
342
343        #[cfg(test)]
344        self.check_invariants();
345
346        Ok((worktree, guest_connection_ids))
347    }
348
349    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
350        if let Some(project) = self.projects.get_mut(&project_id) {
351            if project.host_connection_id == connection_id {
352                project.share = Some(ProjectShare::default());
353                return true;
354            }
355        }
356        false
357    }
358
359    pub fn unshare_project(
360        &mut self,
361        project_id: u64,
362        acting_connection_id: ConnectionId,
363    ) -> tide::Result<UnsharedProject> {
364        let project = if let Some(project) = self.projects.get_mut(&project_id) {
365            project
366        } else {
367            return Err(anyhow!("no such project"))?;
368        };
369
370        if project.host_connection_id != acting_connection_id {
371            return Err(anyhow!("not your project"))?;
372        }
373
374        let connection_ids = project.connection_ids();
375        let authorized_user_ids = project.authorized_user_ids();
376        if let Some(share) = project.share.take() {
377            for connection_id in share.guests.into_keys() {
378                if let Some(connection) = self.connections.get_mut(&connection_id) {
379                    connection.projects.remove(&project_id);
380                }
381            }
382
383            for worktree in project.worktrees.values_mut() {
384                worktree.share.take();
385            }
386
387            #[cfg(test)]
388            self.check_invariants();
389
390            Ok(UnsharedProject {
391                connection_ids,
392                authorized_user_ids,
393            })
394        } else {
395            Err(anyhow!("project is not shared"))?
396        }
397    }
398
399    pub fn share_worktree(
400        &mut self,
401        project_id: u64,
402        worktree_id: u64,
403        connection_id: ConnectionId,
404        entries: HashMap<u64, proto::Entry>,
405        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
406    ) -> tide::Result<SharedWorktree> {
407        let project = self
408            .projects
409            .get_mut(&project_id)
410            .ok_or_else(|| anyhow!("no such project"))?;
411        let worktree = project
412            .worktrees
413            .get_mut(&worktree_id)
414            .ok_or_else(|| anyhow!("no such worktree"))?;
415        if project.host_connection_id == connection_id && project.share.is_some() {
416            worktree.share = Some(WorktreeShare {
417                entries,
418                diagnostic_summaries,
419            });
420            Ok(SharedWorktree {
421                authorized_user_ids: project.authorized_user_ids(),
422                connection_ids: project.guest_connection_ids(),
423            })
424        } else {
425            Err(anyhow!("no such worktree"))?
426        }
427    }
428
429    pub fn update_diagnostic_summary(
430        &mut self,
431        project_id: u64,
432        worktree_id: u64,
433        connection_id: ConnectionId,
434        summary: proto::DiagnosticSummary,
435    ) -> tide::Result<Vec<ConnectionId>> {
436        let project = self
437            .projects
438            .get_mut(&project_id)
439            .ok_or_else(|| anyhow!("no such project"))?;
440        let worktree = project
441            .worktrees
442            .get_mut(&worktree_id)
443            .ok_or_else(|| anyhow!("no such worktree"))?;
444        if project.host_connection_id == connection_id {
445            if let Some(share) = worktree.share.as_mut() {
446                share
447                    .diagnostic_summaries
448                    .insert(summary.path.clone().into(), summary);
449                return Ok(project.connection_ids());
450            }
451        }
452
453        Err(anyhow!("no such worktree"))?
454    }
455
456    pub fn join_project(
457        &mut self,
458        connection_id: ConnectionId,
459        user_id: UserId,
460        project_id: u64,
461    ) -> tide::Result<JoinedProject> {
462        let connection = self
463            .connections
464            .get_mut(&connection_id)
465            .ok_or_else(|| anyhow!("no such connection"))?;
466        let project = self
467            .projects
468            .get_mut(&project_id)
469            .and_then(|project| {
470                if project.has_authorized_user_id(user_id) {
471                    Some(project)
472                } else {
473                    None
474                }
475            })
476            .ok_or_else(|| anyhow!("no such project"))?;
477
478        let share = project.share_mut()?;
479        connection.projects.insert(project_id);
480
481        let mut replica_id = 1;
482        while share.active_replica_ids.contains(&replica_id) {
483            replica_id += 1;
484        }
485        share.active_replica_ids.insert(replica_id);
486        share.guests.insert(connection_id, (replica_id, user_id));
487
488        #[cfg(test)]
489        self.check_invariants();
490
491        Ok(JoinedProject {
492            replica_id,
493            project: &self.projects[&project_id],
494        })
495    }
496
497    pub fn leave_project(
498        &mut self,
499        connection_id: ConnectionId,
500        project_id: u64,
501    ) -> tide::Result<LeftProject> {
502        let project = self
503            .projects
504            .get_mut(&project_id)
505            .ok_or_else(|| anyhow!("no such project"))?;
506        let share = project
507            .share
508            .as_mut()
509            .ok_or_else(|| anyhow!("project is not shared"))?;
510        let (replica_id, _) = share
511            .guests
512            .remove(&connection_id)
513            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514        share.active_replica_ids.remove(&replica_id);
515
516        if let Some(connection) = self.connections.get_mut(&connection_id) {
517            connection.projects.remove(&project_id);
518        }
519
520        let connection_ids = project.connection_ids();
521        let authorized_user_ids = project.authorized_user_ids();
522
523        #[cfg(test)]
524        self.check_invariants();
525
526        Ok(LeftProject {
527            connection_ids,
528            authorized_user_ids,
529        })
530    }
531
532    pub fn update_worktree(
533        &mut self,
534        connection_id: ConnectionId,
535        project_id: u64,
536        worktree_id: u64,
537        removed_entries: &[u64],
538        updated_entries: &[proto::Entry],
539    ) -> tide::Result<Vec<ConnectionId>> {
540        let project = self.write_project(project_id, connection_id)?;
541        let share = project
542            .worktrees
543            .get_mut(&worktree_id)
544            .ok_or_else(|| anyhow!("no such worktree"))?
545            .share
546            .as_mut()
547            .ok_or_else(|| anyhow!("worktree is not shared"))?;
548        for entry_id in removed_entries {
549            share.entries.remove(&entry_id);
550        }
551        for entry in updated_entries {
552            share.entries.insert(entry.id, entry.clone());
553        }
554        Ok(project.connection_ids())
555    }
556
557    pub fn project_connection_ids(
558        &self,
559        project_id: u64,
560        acting_connection_id: ConnectionId,
561    ) -> tide::Result<Vec<ConnectionId>> {
562        Ok(self
563            .read_project(project_id, acting_connection_id)?
564            .connection_ids())
565    }
566
567    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
568        Ok(self
569            .channels
570            .get(&channel_id)
571            .ok_or_else(|| anyhow!("no such channel"))?
572            .connection_ids())
573    }
574
575    #[cfg(test)]
576    pub fn project(&self, project_id: u64) -> Option<&Project> {
577        self.projects.get(&project_id)
578    }
579
580    pub fn read_project(
581        &self,
582        project_id: u64,
583        connection_id: ConnectionId,
584    ) -> tide::Result<&Project> {
585        let project = self
586            .projects
587            .get(&project_id)
588            .ok_or_else(|| anyhow!("no such project"))?;
589        if project.host_connection_id == connection_id
590            || project
591                .share
592                .as_ref()
593                .ok_or_else(|| anyhow!("project is not shared"))?
594                .guests
595                .contains_key(&connection_id)
596        {
597            Ok(project)
598        } else {
599            Err(anyhow!("no such project"))?
600        }
601    }
602
603    fn write_project(
604        &mut self,
605        project_id: u64,
606        connection_id: ConnectionId,
607    ) -> tide::Result<&mut Project> {
608        let project = self
609            .projects
610            .get_mut(&project_id)
611            .ok_or_else(|| anyhow!("no such project"))?;
612        if project.host_connection_id == connection_id
613            || project
614                .share
615                .as_ref()
616                .ok_or_else(|| anyhow!("project is not shared"))?
617                .guests
618                .contains_key(&connection_id)
619        {
620            Ok(project)
621        } else {
622            Err(anyhow!("no such project"))?
623        }
624    }
625
626    #[cfg(test)]
627    fn check_invariants(&self) {
628        for (connection_id, connection) in &self.connections {
629            for project_id in &connection.projects {
630                let project = &self.projects.get(&project_id).unwrap();
631                if project.host_connection_id != *connection_id {
632                    assert!(project
633                        .share
634                        .as_ref()
635                        .unwrap()
636                        .guests
637                        .contains_key(connection_id));
638                }
639            }
640            for channel_id in &connection.channels {
641                let channel = self.channels.get(channel_id).unwrap();
642                assert!(channel.connection_ids.contains(connection_id));
643            }
644            assert!(self
645                .connections_by_user_id
646                .get(&connection.user_id)
647                .unwrap()
648                .contains(connection_id));
649        }
650
651        for (user_id, connection_ids) in &self.connections_by_user_id {
652            for connection_id in connection_ids {
653                assert_eq!(
654                    self.connections.get(connection_id).unwrap().user_id,
655                    *user_id
656                );
657            }
658        }
659
660        for (project_id, project) in &self.projects {
661            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
662            assert!(host_connection.projects.contains(project_id));
663
664            for authorized_user_ids in project.authorized_user_ids() {
665                let visible_project_ids = self
666                    .visible_projects_by_user_id
667                    .get(&authorized_user_ids)
668                    .unwrap();
669                assert!(visible_project_ids.contains(project_id));
670            }
671
672            if let Some(share) = &project.share {
673                for guest_connection_id in share.guests.keys() {
674                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
675                    assert!(guest_connection.projects.contains(project_id));
676                }
677                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
678                assert_eq!(
679                    share.active_replica_ids,
680                    share
681                        .guests
682                        .values()
683                        .map(|(replica_id, _)| *replica_id)
684                        .collect::<HashSet<_>>(),
685                );
686            }
687        }
688
689        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
690            for project_id in visible_project_ids {
691                let project = self.projects.get(project_id).unwrap();
692                assert!(project.authorized_user_ids().contains(user_id));
693            }
694        }
695
696        for (channel_id, channel) in &self.channels {
697            for connection_id in &channel.connection_ids {
698                let connection = self.connections.get(connection_id).unwrap();
699                assert!(connection.channels.contains(channel_id));
700            }
701        }
702    }
703}
704
705impl Project {
706    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
707        self.worktrees
708            .values()
709            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
710    }
711
712    pub fn authorized_user_ids(&self) -> Vec<UserId> {
713        let mut ids = self
714            .worktrees
715            .values()
716            .flat_map(|worktree| worktree.authorized_user_ids.iter())
717            .copied()
718            .collect::<Vec<_>>();
719        ids.sort_unstable();
720        ids.dedup();
721        ids
722    }
723
724    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
725        if let Some(share) = &self.share {
726            share.guests.keys().copied().collect()
727        } else {
728            Vec::new()
729        }
730    }
731
732    pub fn connection_ids(&self) -> Vec<ConnectionId> {
733        if let Some(share) = &self.share {
734            share
735                .guests
736                .keys()
737                .copied()
738                .chain(Some(self.host_connection_id))
739                .collect()
740        } else {
741            vec![self.host_connection_id]
742        }
743    }
744
745    pub fn share(&self) -> tide::Result<&ProjectShare> {
746        Ok(self
747            .share
748            .as_ref()
749            .ok_or_else(|| anyhow!("worktree is not shared"))?)
750    }
751
752    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
753        Ok(self
754            .share
755            .as_mut()
756            .ok_or_else(|| anyhow!("worktree is not shared"))?)
757    }
758}
759
760impl Channel {
761    fn connection_ids(&self) -> Vec<ConnectionId> {
762        self.connection_ids.iter().copied().collect()
763    }
764}