store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34    pub weak: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41}
 42
 43pub struct WorktreeShare {
 44    pub entries: HashMap<u64, proto::Entry>,
 45    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 46}
 47
 48#[derive(Default)]
 49pub struct Channel {
 50    pub connection_ids: HashSet<ConnectionId>,
 51}
 52
 53pub type ReplicaId = u16;
 54
 55#[derive(Default)]
 56pub struct RemovedConnectionState {
 57    pub hosted_projects: HashMap<u64, Project>,
 58    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 59    pub contact_ids: HashSet<UserId>,
 60}
 61
 62pub struct JoinedProject<'a> {
 63    pub replica_id: ReplicaId,
 64    pub project: &'a Project,
 65}
 66
 67pub struct UnsharedProject {
 68    pub connection_ids: Vec<ConnectionId>,
 69    pub authorized_user_ids: Vec<UserId>,
 70}
 71
 72pub struct LeftProject {
 73    pub connection_ids: Vec<ConnectionId>,
 74    pub authorized_user_ids: Vec<UserId>,
 75}
 76
 77pub struct SharedWorktree {
 78    pub authorized_user_ids: Vec<UserId>,
 79    pub connection_ids: Vec<ConnectionId>,
 80}
 81
 82impl Store {
 83    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 84        self.connections.insert(
 85            connection_id,
 86            ConnectionState {
 87                user_id,
 88                projects: Default::default(),
 89                channels: Default::default(),
 90            },
 91        );
 92        self.connections_by_user_id
 93            .entry(user_id)
 94            .or_default()
 95            .insert(connection_id);
 96    }
 97
 98    pub fn remove_connection(
 99        &mut self,
100        connection_id: ConnectionId,
101    ) -> tide::Result<RemovedConnectionState> {
102        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
103            connection
104        } else {
105            return Err(anyhow!("no such connection"))?;
106        };
107
108        for channel_id in &connection.channels {
109            if let Some(channel) = self.channels.get_mut(&channel_id) {
110                channel.connection_ids.remove(&connection_id);
111            }
112        }
113
114        let user_connections = self
115            .connections_by_user_id
116            .get_mut(&connection.user_id)
117            .unwrap();
118        user_connections.remove(&connection_id);
119        if user_connections.is_empty() {
120            self.connections_by_user_id.remove(&connection.user_id);
121        }
122
123        let mut result = RemovedConnectionState::default();
124        for project_id in connection.projects.clone() {
125            if let Some(project) = self.unregister_project(project_id, connection_id) {
126                result.contact_ids.extend(project.authorized_user_ids());
127                result.hosted_projects.insert(project_id, project);
128            } else if let Some(project) = self.leave_project(connection_id, project_id) {
129                result
130                    .guest_project_ids
131                    .insert(project_id, project.connection_ids);
132                result.contact_ids.extend(project.authorized_user_ids);
133            }
134        }
135
136        #[cfg(test)]
137        self.check_invariants();
138
139        Ok(result)
140    }
141
142    #[cfg(test)]
143    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
144        self.channels.get(&id)
145    }
146
147    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
148        if let Some(connection) = self.connections.get_mut(&connection_id) {
149            connection.channels.insert(channel_id);
150            self.channels
151                .entry(channel_id)
152                .or_default()
153                .connection_ids
154                .insert(connection_id);
155        }
156    }
157
158    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
159        if let Some(connection) = self.connections.get_mut(&connection_id) {
160            connection.channels.remove(&channel_id);
161            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
162                entry.get_mut().connection_ids.remove(&connection_id);
163                if entry.get_mut().connection_ids.is_empty() {
164                    entry.remove();
165                }
166            }
167        }
168    }
169
170    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
171        Ok(self
172            .connections
173            .get(&connection_id)
174            .ok_or_else(|| anyhow!("unknown connection"))?
175            .user_id)
176    }
177
178    pub fn connection_ids_for_user<'a>(
179        &'a self,
180        user_id: UserId,
181    ) -> impl 'a + Iterator<Item = ConnectionId> {
182        self.connections_by_user_id
183            .get(&user_id)
184            .into_iter()
185            .flatten()
186            .copied()
187    }
188
189    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
190        let mut contacts = HashMap::default();
191        for project_id in self
192            .visible_projects_by_user_id
193            .get(&user_id)
194            .unwrap_or(&HashSet::default())
195        {
196            let project = &self.projects[project_id];
197
198            let mut guests = HashSet::default();
199            if let Ok(share) = project.share() {
200                for guest_connection_id in share.guests.keys() {
201                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
202                        guests.insert(user_id.to_proto());
203                    }
204                }
205            }
206
207            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
208                let mut worktree_root_names = project
209                    .worktrees
210                    .values()
211                    .filter(|worktree| !worktree.weak)
212                    .map(|worktree| worktree.root_name.clone())
213                    .collect::<Vec<_>>();
214                worktree_root_names.sort_unstable();
215                contacts
216                    .entry(host_user_id)
217                    .or_insert_with(|| proto::Contact {
218                        user_id: host_user_id.to_proto(),
219                        projects: Vec::new(),
220                    })
221                    .projects
222                    .push(proto::ProjectMetadata {
223                        id: *project_id,
224                        worktree_root_names,
225                        is_shared: project.share.is_some(),
226                        guests: guests.into_iter().collect(),
227                    });
228            }
229        }
230
231        contacts.into_values().collect()
232    }
233
234    pub fn register_project(
235        &mut self,
236        host_connection_id: ConnectionId,
237        host_user_id: UserId,
238    ) -> u64 {
239        let project_id = self.next_project_id;
240        self.projects.insert(
241            project_id,
242            Project {
243                host_connection_id,
244                host_user_id,
245                share: None,
246                worktrees: Default::default(),
247            },
248        );
249        self.next_project_id += 1;
250        project_id
251    }
252
253    pub fn register_worktree(
254        &mut self,
255        project_id: u64,
256        worktree_id: u64,
257        worktree: Worktree,
258    ) -> bool {
259        if let Some(project) = self.projects.get_mut(&project_id) {
260            for authorized_user_id in &worktree.authorized_user_ids {
261                self.visible_projects_by_user_id
262                    .entry(*authorized_user_id)
263                    .or_default()
264                    .insert(project_id);
265            }
266            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
267                connection.projects.insert(project_id);
268            }
269            project.worktrees.insert(worktree_id, worktree);
270
271            #[cfg(test)]
272            self.check_invariants();
273            true
274        } else {
275            false
276        }
277    }
278
279    pub fn unregister_project(
280        &mut self,
281        project_id: u64,
282        connection_id: ConnectionId,
283    ) -> Option<Project> {
284        match self.projects.entry(project_id) {
285            hash_map::Entry::Occupied(e) => {
286                if e.get().host_connection_id == connection_id {
287                    for user_id in e.get().authorized_user_ids() {
288                        if let hash_map::Entry::Occupied(mut projects) =
289                            self.visible_projects_by_user_id.entry(user_id)
290                        {
291                            projects.get_mut().remove(&project_id);
292                        }
293                    }
294
295                    Some(e.remove())
296                } else {
297                    None
298                }
299            }
300            hash_map::Entry::Vacant(_) => None,
301        }
302    }
303
304    pub fn unregister_worktree(
305        &mut self,
306        project_id: u64,
307        worktree_id: u64,
308        acting_connection_id: ConnectionId,
309    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
310        let project = self
311            .projects
312            .get_mut(&project_id)
313            .ok_or_else(|| anyhow!("no such project"))?;
314        if project.host_connection_id != acting_connection_id {
315            Err(anyhow!("not your worktree"))?;
316        }
317
318        let worktree = project
319            .worktrees
320            .remove(&worktree_id)
321            .ok_or_else(|| anyhow!("no such worktree"))?;
322
323        let mut guest_connection_ids = Vec::new();
324        if let Some(share) = &project.share {
325            guest_connection_ids.extend(share.guests.keys());
326        }
327
328        for authorized_user_id in &worktree.authorized_user_ids {
329            if let Some(visible_projects) =
330                self.visible_projects_by_user_id.get_mut(authorized_user_id)
331            {
332                if !project.has_authorized_user_id(*authorized_user_id) {
333                    visible_projects.remove(&project_id);
334                }
335            }
336        }
337
338        #[cfg(test)]
339        self.check_invariants();
340
341        Ok((worktree, guest_connection_ids))
342    }
343
344    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
345        if let Some(project) = self.projects.get_mut(&project_id) {
346            if project.host_connection_id == connection_id {
347                project.share = Some(ProjectShare::default());
348                return true;
349            }
350        }
351        false
352    }
353
354    pub fn unshare_project(
355        &mut self,
356        project_id: u64,
357        acting_connection_id: ConnectionId,
358    ) -> tide::Result<UnsharedProject> {
359        let project = if let Some(project) = self.projects.get_mut(&project_id) {
360            project
361        } else {
362            return Err(anyhow!("no such project"))?;
363        };
364
365        if project.host_connection_id != acting_connection_id {
366            return Err(anyhow!("not your project"))?;
367        }
368
369        let connection_ids = project.connection_ids();
370        let authorized_user_ids = project.authorized_user_ids();
371        if let Some(share) = project.share.take() {
372            for connection_id in share.guests.into_keys() {
373                if let Some(connection) = self.connections.get_mut(&connection_id) {
374                    connection.projects.remove(&project_id);
375                }
376            }
377
378            for worktree in project.worktrees.values_mut() {
379                worktree.share.take();
380            }
381
382            #[cfg(test)]
383            self.check_invariants();
384
385            Ok(UnsharedProject {
386                connection_ids,
387                authorized_user_ids,
388            })
389        } else {
390            Err(anyhow!("project is not shared"))?
391        }
392    }
393
394    pub fn share_worktree(
395        &mut self,
396        project_id: u64,
397        worktree_id: u64,
398        connection_id: ConnectionId,
399        entries: HashMap<u64, proto::Entry>,
400        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
401    ) -> Option<SharedWorktree> {
402        let project = self.projects.get_mut(&project_id)?;
403        let worktree = project.worktrees.get_mut(&worktree_id)?;
404        if project.host_connection_id == connection_id && project.share.is_some() {
405            worktree.share = Some(WorktreeShare {
406                entries,
407                diagnostic_summaries,
408            });
409            Some(SharedWorktree {
410                authorized_user_ids: project.authorized_user_ids(),
411                connection_ids: project.guest_connection_ids(),
412            })
413        } else {
414            None
415        }
416    }
417
418    pub fn update_diagnostic_summary(
419        &mut self,
420        project_id: u64,
421        worktree_id: u64,
422        connection_id: ConnectionId,
423        summary: proto::DiagnosticSummary,
424    ) -> Option<Vec<ConnectionId>> {
425        let project = self.projects.get_mut(&project_id)?;
426        let worktree = project.worktrees.get_mut(&worktree_id)?;
427        if project.host_connection_id == connection_id {
428            if let Some(share) = worktree.share.as_mut() {
429                share
430                    .diagnostic_summaries
431                    .insert(summary.path.clone().into(), summary);
432                return Some(project.connection_ids());
433            }
434        }
435
436        None
437    }
438
439    pub fn join_project(
440        &mut self,
441        connection_id: ConnectionId,
442        user_id: UserId,
443        project_id: u64,
444    ) -> tide::Result<JoinedProject> {
445        let connection = self
446            .connections
447            .get_mut(&connection_id)
448            .ok_or_else(|| anyhow!("no such connection"))?;
449        let project = self
450            .projects
451            .get_mut(&project_id)
452            .and_then(|project| {
453                if project.has_authorized_user_id(user_id) {
454                    Some(project)
455                } else {
456                    None
457                }
458            })
459            .ok_or_else(|| anyhow!("no such project"))?;
460
461        let share = project.share_mut()?;
462        connection.projects.insert(project_id);
463
464        let mut replica_id = 1;
465        while share.active_replica_ids.contains(&replica_id) {
466            replica_id += 1;
467        }
468        share.active_replica_ids.insert(replica_id);
469        share.guests.insert(connection_id, (replica_id, user_id));
470
471        #[cfg(test)]
472        self.check_invariants();
473
474        Ok(JoinedProject {
475            replica_id,
476            project: &self.projects[&project_id],
477        })
478    }
479
480    pub fn leave_project(
481        &mut self,
482        connection_id: ConnectionId,
483        project_id: u64,
484    ) -> Option<LeftProject> {
485        let project = self.projects.get_mut(&project_id)?;
486        let share = project.share.as_mut()?;
487        let (replica_id, _) = share.guests.remove(&connection_id)?;
488        share.active_replica_ids.remove(&replica_id);
489
490        if let Some(connection) = self.connections.get_mut(&connection_id) {
491            connection.projects.remove(&project_id);
492        }
493
494        let connection_ids = project.connection_ids();
495        let authorized_user_ids = project.authorized_user_ids();
496
497        #[cfg(test)]
498        self.check_invariants();
499
500        Some(LeftProject {
501            connection_ids,
502            authorized_user_ids,
503        })
504    }
505
506    pub fn update_worktree(
507        &mut self,
508        connection_id: ConnectionId,
509        project_id: u64,
510        worktree_id: u64,
511        removed_entries: &[u64],
512        updated_entries: &[proto::Entry],
513    ) -> Option<Vec<ConnectionId>> {
514        let project = self.write_project(project_id, connection_id)?;
515        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
516        for entry_id in removed_entries {
517            share.entries.remove(&entry_id);
518        }
519        for entry in updated_entries {
520            share.entries.insert(entry.id, entry.clone());
521        }
522        Some(project.connection_ids())
523    }
524
525    pub fn project_connection_ids(
526        &self,
527        project_id: u64,
528        acting_connection_id: ConnectionId,
529    ) -> Option<Vec<ConnectionId>> {
530        Some(
531            self.read_project(project_id, acting_connection_id)?
532                .connection_ids(),
533        )
534    }
535
536    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
537        Some(self.channels.get(&channel_id)?.connection_ids())
538    }
539
540    #[cfg(test)]
541    pub fn project(&self, project_id: u64) -> Option<&Project> {
542        self.projects.get(&project_id)
543    }
544
545    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
546        let project = self.projects.get(&project_id)?;
547        if project.host_connection_id == connection_id
548            || project.share.as_ref()?.guests.contains_key(&connection_id)
549        {
550            Some(project)
551        } else {
552            None
553        }
554    }
555
556    fn write_project(
557        &mut self,
558        project_id: u64,
559        connection_id: ConnectionId,
560    ) -> Option<&mut Project> {
561        let project = self.projects.get_mut(&project_id)?;
562        if project.host_connection_id == connection_id
563            || project.share.as_ref()?.guests.contains_key(&connection_id)
564        {
565            Some(project)
566        } else {
567            None
568        }
569    }
570
571    #[cfg(test)]
572    fn check_invariants(&self) {
573        for (connection_id, connection) in &self.connections {
574            for project_id in &connection.projects {
575                let project = &self.projects.get(&project_id).unwrap();
576                if project.host_connection_id != *connection_id {
577                    assert!(project
578                        .share
579                        .as_ref()
580                        .unwrap()
581                        .guests
582                        .contains_key(connection_id));
583                }
584            }
585            for channel_id in &connection.channels {
586                let channel = self.channels.get(channel_id).unwrap();
587                assert!(channel.connection_ids.contains(connection_id));
588            }
589            assert!(self
590                .connections_by_user_id
591                .get(&connection.user_id)
592                .unwrap()
593                .contains(connection_id));
594        }
595
596        for (user_id, connection_ids) in &self.connections_by_user_id {
597            for connection_id in connection_ids {
598                assert_eq!(
599                    self.connections.get(connection_id).unwrap().user_id,
600                    *user_id
601                );
602            }
603        }
604
605        for (project_id, project) in &self.projects {
606            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
607            assert!(host_connection.projects.contains(project_id));
608
609            for authorized_user_ids in project.authorized_user_ids() {
610                let visible_project_ids = self
611                    .visible_projects_by_user_id
612                    .get(&authorized_user_ids)
613                    .unwrap();
614                assert!(visible_project_ids.contains(project_id));
615            }
616
617            if let Some(share) = &project.share {
618                for guest_connection_id in share.guests.keys() {
619                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
620                    assert!(guest_connection.projects.contains(project_id));
621                }
622                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
623                assert_eq!(
624                    share.active_replica_ids,
625                    share
626                        .guests
627                        .values()
628                        .map(|(replica_id, _)| *replica_id)
629                        .collect::<HashSet<_>>(),
630                );
631            }
632        }
633
634        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
635            for project_id in visible_project_ids {
636                let project = self.projects.get(project_id).unwrap();
637                assert!(project.authorized_user_ids().contains(user_id));
638            }
639        }
640
641        for (channel_id, channel) in &self.channels {
642            for connection_id in &channel.connection_ids {
643                let connection = self.connections.get(connection_id).unwrap();
644                assert!(connection.channels.contains(channel_id));
645            }
646        }
647    }
648}
649
650impl Project {
651    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
652        self.worktrees
653            .values()
654            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
655    }
656
657    pub fn authorized_user_ids(&self) -> Vec<UserId> {
658        let mut ids = self
659            .worktrees
660            .values()
661            .flat_map(|worktree| worktree.authorized_user_ids.iter())
662            .copied()
663            .collect::<Vec<_>>();
664        ids.sort_unstable();
665        ids.dedup();
666        ids
667    }
668
669    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
670        if let Some(share) = &self.share {
671            share.guests.keys().copied().collect()
672        } else {
673            Vec::new()
674        }
675    }
676
677    pub fn connection_ids(&self) -> Vec<ConnectionId> {
678        if let Some(share) = &self.share {
679            share
680                .guests
681                .keys()
682                .copied()
683                .chain(Some(self.host_connection_id))
684                .collect()
685        } else {
686            vec![self.host_connection_id]
687        }
688    }
689
690    pub fn share(&self) -> tide::Result<&ProjectShare> {
691        Ok(self
692            .share
693            .as_ref()
694            .ok_or_else(|| anyhow!("worktree is not shared"))?)
695    }
696
697    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
698        Ok(self
699            .share
700            .as_mut()
701            .ok_or_else(|| anyhow!("worktree is not shared"))?)
702    }
703}
704
705impl Channel {
706    fn connection_ids(&self) -> Vec<ConnectionId> {
707        self.connection_ids.iter().copied().collect()
708    }
709}