store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
/// In-memory registry of connections, projects and channels.
///
/// `Default` produces an empty store whose first allocated project id is 0.
  7#[derive(Default)]
  8pub struct Store {
    /// State of every registered connection, keyed by connection id.
  9    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index: the connection ids currently registered for each user.
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Registered projects, keyed by project id.
 11    projects: HashMap<u64, Project>,
    /// Ids of the projects each user is authorized for through at least one worktree.
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channels that currently have state, keyed by channel id.
 13    channels: HashMap<ChannelId, Channel>,
    /// Id that the next `register_project` call will hand out.
 14    next_project_id: u64,
 15}
 16
/// Bookkeeping the store keeps for a single connection.
 17struct ConnectionState {
    /// User the connection was registered for.
 18    user_id: UserId,
    /// Ids of the projects this connection is attached to, as host or as guest.
 19    projects: HashSet<u64>,
    /// Ids of the channels this connection has joined.
 20    channels: HashSet<ChannelId>,
 21}
 22
/// A project registered by a host connection.
 23pub struct Project {
    /// Connection that registered the project; share/unshare/unregister requests are
    /// checked against it.
 24    pub host_connection_id: ConnectionId,
    /// User id supplied when the project was registered.
 25    pub host_user_id: UserId,
    /// Sharing state; `None` while the project is not shared.
 26    pub share: Option<ProjectShare>,
    /// Worktrees registered under this project, keyed by worktree id.
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
/// A worktree registered under a project.
 30pub struct Worktree {
    /// Users who may see and join the owning project because of this worktree.
 31    pub authorized_user_ids: Vec<UserId>,
    /// Root name, reported in contact metadata for non-weak worktrees.
 32    pub root_name: String,
    /// Shared contents; `None` until the worktree is shared.
 33    pub share: Option<WorktreeShare>,
    /// Weak worktrees are left out of the root names reported to contacts.
 34    pub weak: bool,
 35}
 36
/// Guest bookkeeping for a project that is currently shared.
 37#[derive(Default)]
 38pub struct ProjectShare {
    /// Joined guest connections, each mapped to its replica id and user id.
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
    /// Replica ids currently assigned to guests.
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41}
 42
/// Entries and diagnostics recorded for a shared worktree.
 43pub struct WorktreeShare {
    /// Worktree entries keyed by entry id.
 44    pub entries: HashMap<u64, proto::Entry>,
    /// Most recent diagnostic summary recorded for each path.
 45    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 46}
 47
/// Membership of a single channel.
 48#[derive(Default)]
 49pub struct Channel {
    /// Connections currently joined to the channel.
 50    pub connection_ids: HashSet<ConnectionId>,
 51}
 52
/// Identifier assigned to a guest when it joins a shared project; `join_project`
/// allocates the lowest unused value starting at 1.
 53pub type ReplicaId = u16;
 54
/// Summary of what `Store::remove_connection` tore down.
 55#[derive(Default)]
 56pub struct RemovedConnectionState {
    /// Projects the connection hosted, removed from the store and keyed by project id.
 57    pub hosted_projects: HashMap<u64, Project>,
    /// For each project the connection left as a guest, the connection ids still
    /// attached to that project.
 58    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Users authorized for any of the affected projects.
 59    pub contact_ids: HashSet<UserId>,
 60}
 61
/// Result of a successful `Store::join_project` call.
 62pub struct JoinedProject<'a> {
    /// Replica id assigned to the joining connection.
 63    pub replica_id: ReplicaId,
    /// The project that was joined, borrowed from the store.
 64    pub project: &'a Project,
 65}
 66
/// Result of a successful `Store::unshare_project` call.
 67pub struct UnsharedProject {
    /// Guest and host connection ids attached to the project just before it was unshared.
 68    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized for the project through any of its worktrees.
 69    pub authorized_user_ids: Vec<UserId>,
 70}
 71
/// Result of a guest leaving a project through `Store::leave_project`.
 72pub struct LeftProject {
    /// Connection ids still attached to the project (remaining guests and the host).
 73    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized for the project through any of its worktrees.
 74    pub authorized_user_ids: Vec<UserId>,
 75}
 76
 77impl Store {
 78    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 79        self.connections.insert(
 80            connection_id,
 81            ConnectionState {
 82                user_id,
 83                projects: Default::default(),
 84                channels: Default::default(),
 85            },
 86        );
 87        self.connections_by_user_id
 88            .entry(user_id)
 89            .or_default()
 90            .insert(connection_id);
 91    }
 92
 93    pub fn remove_connection(
 94        &mut self,
 95        connection_id: ConnectionId,
 96    ) -> tide::Result<RemovedConnectionState> {
 97        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 98            connection
 99        } else {
100            return Err(anyhow!("no such connection"))?;
101        };
102
103        for channel_id in &connection.channels {
104            if let Some(channel) = self.channels.get_mut(&channel_id) {
105                channel.connection_ids.remove(&connection_id);
106            }
107        }
108
109        let user_connections = self
110            .connections_by_user_id
111            .get_mut(&connection.user_id)
112            .unwrap();
113        user_connections.remove(&connection_id);
114        if user_connections.is_empty() {
115            self.connections_by_user_id.remove(&connection.user_id);
116        }
117
118        let mut result = RemovedConnectionState::default();
119        for project_id in connection.projects.clone() {
120            if let Some(project) = self.unregister_project(project_id, connection_id) {
121                result.contact_ids.extend(project.authorized_user_ids());
122                result.hosted_projects.insert(project_id, project);
123            } else if let Some(project) = self.leave_project(connection_id, project_id) {
124                result
125                    .guest_project_ids
126                    .insert(project_id, project.connection_ids);
127                result.contact_ids.extend(project.authorized_user_ids);
128            }
129        }
130
131        #[cfg(test)]
132        self.check_invariants();
133
134        Ok(result)
135    }
136
137    #[cfg(test)]
138    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
139        self.channels.get(&id)
140    }
141
142    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143        if let Some(connection) = self.connections.get_mut(&connection_id) {
144            connection.channels.insert(channel_id);
145            self.channels
146                .entry(channel_id)
147                .or_default()
148                .connection_ids
149                .insert(connection_id);
150        }
151    }
152
153    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
154        if let Some(connection) = self.connections.get_mut(&connection_id) {
155            connection.channels.remove(&channel_id);
156            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
157                entry.get_mut().connection_ids.remove(&connection_id);
158                if entry.get_mut().connection_ids.is_empty() {
159                    entry.remove();
160                }
161            }
162        }
163    }
164
165    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
166        Ok(self
167            .connections
168            .get(&connection_id)
169            .ok_or_else(|| anyhow!("unknown connection"))?
170            .user_id)
171    }
172
173    pub fn connection_ids_for_user<'a>(
174        &'a self,
175        user_id: UserId,
176    ) -> impl 'a + Iterator<Item = ConnectionId> {
177        self.connections_by_user_id
178            .get(&user_id)
179            .into_iter()
180            .flatten()
181            .copied()
182    }
183
184    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
185        let mut contacts = HashMap::default();
186        for project_id in self
187            .visible_projects_by_user_id
188            .get(&user_id)
189            .unwrap_or(&HashSet::default())
190        {
191            let project = &self.projects[project_id];
192
193            let mut guests = HashSet::default();
194            if let Ok(share) = project.share() {
195                for guest_connection_id in share.guests.keys() {
196                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
197                        guests.insert(user_id.to_proto());
198                    }
199                }
200            }
201
202            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
203                let mut worktree_root_names = project
204                    .worktrees
205                    .values()
206                    .filter(|worktree| !worktree.weak)
207                    .map(|worktree| worktree.root_name.clone())
208                    .collect::<Vec<_>>();
209                worktree_root_names.sort_unstable();
210                contacts
211                    .entry(host_user_id)
212                    .or_insert_with(|| proto::Contact {
213                        user_id: host_user_id.to_proto(),
214                        projects: Vec::new(),
215                    })
216                    .projects
217                    .push(proto::ProjectMetadata {
218                        id: *project_id,
219                        worktree_root_names,
220                        is_shared: project.share.is_some(),
221                        guests: guests.into_iter().collect(),
222                    });
223            }
224        }
225
226        contacts.into_values().collect()
227    }
228
229    pub fn register_project(
230        &mut self,
231        host_connection_id: ConnectionId,
232        host_user_id: UserId,
233    ) -> u64 {
234        let project_id = self.next_project_id;
235        self.projects.insert(
236            project_id,
237            Project {
238                host_connection_id,
239                host_user_id,
240                share: None,
241                worktrees: Default::default(),
242            },
243        );
244        self.next_project_id += 1;
245        project_id
246    }
247
248    pub fn register_worktree(
249        &mut self,
250        project_id: u64,
251        worktree_id: u64,
252        worktree: Worktree,
253    ) -> bool {
254        if let Some(project) = self.projects.get_mut(&project_id) {
255            for authorized_user_id in &worktree.authorized_user_ids {
256                self.visible_projects_by_user_id
257                    .entry(*authorized_user_id)
258                    .or_default()
259                    .insert(project_id);
260            }
261            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
262                connection.projects.insert(project_id);
263            }
264            project.worktrees.insert(worktree_id, worktree);
265
266            #[cfg(test)]
267            self.check_invariants();
268            true
269        } else {
270            false
271        }
272    }
273
274    pub fn unregister_project(
275        &mut self,
276        project_id: u64,
277        connection_id: ConnectionId,
278    ) -> Option<Project> {
279        match self.projects.entry(project_id) {
280            hash_map::Entry::Occupied(e) => {
281                if e.get().host_connection_id == connection_id {
282                    for user_id in e.get().authorized_user_ids() {
283                        if let hash_map::Entry::Occupied(mut projects) =
284                            self.visible_projects_by_user_id.entry(user_id)
285                        {
286                            projects.get_mut().remove(&project_id);
287                        }
288                    }
289
290                    Some(e.remove())
291                } else {
292                    None
293                }
294            }
295            hash_map::Entry::Vacant(_) => None,
296        }
297    }
298
299    pub fn unregister_worktree(
300        &mut self,
301        project_id: u64,
302        worktree_id: u64,
303        acting_connection_id: ConnectionId,
304    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
305        let project = self
306            .projects
307            .get_mut(&project_id)
308            .ok_or_else(|| anyhow!("no such project"))?;
309        if project.host_connection_id != acting_connection_id {
310            Err(anyhow!("not your worktree"))?;
311        }
312
313        let worktree = project
314            .worktrees
315            .remove(&worktree_id)
316            .ok_or_else(|| anyhow!("no such worktree"))?;
317
318        let mut guest_connection_ids = Vec::new();
319        if let Some(share) = &project.share {
320            guest_connection_ids.extend(share.guests.keys());
321        }
322
323        for authorized_user_id in &worktree.authorized_user_ids {
324            if let Some(visible_projects) =
325                self.visible_projects_by_user_id.get_mut(authorized_user_id)
326            {
327                if !project.has_authorized_user_id(*authorized_user_id) {
328                    visible_projects.remove(&project_id);
329                }
330            }
331        }
332
333        #[cfg(test)]
334        self.check_invariants();
335
336        Ok((worktree, guest_connection_ids))
337    }
338
339    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
340        if let Some(project) = self.projects.get_mut(&project_id) {
341            if project.host_connection_id == connection_id {
342                project.share = Some(ProjectShare::default());
343                return true;
344            }
345        }
346        false
347    }
348
349    pub fn unshare_project(
350        &mut self,
351        project_id: u64,
352        acting_connection_id: ConnectionId,
353    ) -> tide::Result<UnsharedProject> {
354        let project = if let Some(project) = self.projects.get_mut(&project_id) {
355            project
356        } else {
357            return Err(anyhow!("no such project"))?;
358        };
359
360        if project.host_connection_id != acting_connection_id {
361            return Err(anyhow!("not your project"))?;
362        }
363
364        let connection_ids = project.connection_ids();
365        let authorized_user_ids = project.authorized_user_ids();
366        if let Some(share) = project.share.take() {
367            for connection_id in share.guests.into_keys() {
368                if let Some(connection) = self.connections.get_mut(&connection_id) {
369                    connection.projects.remove(&project_id);
370                }
371            }
372
373            for worktree in project.worktrees.values_mut() {
374                worktree.share.take();
375            }
376
377            #[cfg(test)]
378            self.check_invariants();
379
380            Ok(UnsharedProject {
381                connection_ids,
382                authorized_user_ids,
383            })
384        } else {
385            Err(anyhow!("project is not shared"))?
386        }
387    }
388
389    pub fn share_worktree(
390        &mut self,
391        project_id: u64,
392        worktree_id: u64,
393        connection_id: ConnectionId,
394        entries: HashMap<u64, proto::Entry>,
395        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
396    ) -> Option<Vec<UserId>> {
397        let project = self.projects.get_mut(&project_id)?;
398        let worktree = project.worktrees.get_mut(&worktree_id)?;
399        if project.host_connection_id == connection_id && project.share.is_some() {
400            worktree.share = Some(WorktreeShare {
401                entries,
402                diagnostic_summaries,
403            });
404            Some(project.authorized_user_ids())
405        } else {
406            None
407        }
408    }
409
410    pub fn update_diagnostic_summary(
411        &mut self,
412        project_id: u64,
413        worktree_id: u64,
414        connection_id: ConnectionId,
415        summary: proto::DiagnosticSummary,
416    ) -> Option<Vec<ConnectionId>> {
417        let project = self.projects.get_mut(&project_id)?;
418        let worktree = project.worktrees.get_mut(&worktree_id)?;
419        if project.host_connection_id == connection_id {
420            if let Some(share) = worktree.share.as_mut() {
421                share
422                    .diagnostic_summaries
423                    .insert(summary.path.clone().into(), summary);
424                return Some(project.connection_ids());
425            }
426        }
427
428        None
429    }
430
431    pub fn join_project(
432        &mut self,
433        connection_id: ConnectionId,
434        user_id: UserId,
435        project_id: u64,
436    ) -> tide::Result<JoinedProject> {
437        let connection = self
438            .connections
439            .get_mut(&connection_id)
440            .ok_or_else(|| anyhow!("no such connection"))?;
441        let project = self
442            .projects
443            .get_mut(&project_id)
444            .and_then(|project| {
445                if project.has_authorized_user_id(user_id) {
446                    Some(project)
447                } else {
448                    None
449                }
450            })
451            .ok_or_else(|| anyhow!("no such project"))?;
452
453        let share = project.share_mut()?;
454        connection.projects.insert(project_id);
455
456        let mut replica_id = 1;
457        while share.active_replica_ids.contains(&replica_id) {
458            replica_id += 1;
459        }
460        share.active_replica_ids.insert(replica_id);
461        share.guests.insert(connection_id, (replica_id, user_id));
462
463        #[cfg(test)]
464        self.check_invariants();
465
466        Ok(JoinedProject {
467            replica_id,
468            project: &self.projects[&project_id],
469        })
470    }
471
472    pub fn leave_project(
473        &mut self,
474        connection_id: ConnectionId,
475        project_id: u64,
476    ) -> Option<LeftProject> {
477        let project = self.projects.get_mut(&project_id)?;
478        let share = project.share.as_mut()?;
479        let (replica_id, _) = share.guests.remove(&connection_id)?;
480        share.active_replica_ids.remove(&replica_id);
481
482        if let Some(connection) = self.connections.get_mut(&connection_id) {
483            connection.projects.remove(&project_id);
484        }
485
486        let connection_ids = project.connection_ids();
487        let authorized_user_ids = project.authorized_user_ids();
488
489        #[cfg(test)]
490        self.check_invariants();
491
492        Some(LeftProject {
493            connection_ids,
494            authorized_user_ids,
495        })
496    }
497
498    pub fn update_worktree(
499        &mut self,
500        connection_id: ConnectionId,
501        project_id: u64,
502        worktree_id: u64,
503        removed_entries: &[u64],
504        updated_entries: &[proto::Entry],
505    ) -> Option<Vec<ConnectionId>> {
506        let project = self.write_project(project_id, connection_id)?;
507        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
508        for entry_id in removed_entries {
509            share.entries.remove(&entry_id);
510        }
511        for entry in updated_entries {
512            share.entries.insert(entry.id, entry.clone());
513        }
514        Some(project.connection_ids())
515    }
516
517    pub fn project_connection_ids(
518        &self,
519        project_id: u64,
520        acting_connection_id: ConnectionId,
521    ) -> Option<Vec<ConnectionId>> {
522        Some(
523            self.read_project(project_id, acting_connection_id)?
524                .connection_ids(),
525        )
526    }
527
528    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
529        Some(self.channels.get(&channel_id)?.connection_ids())
530    }
531
532    #[cfg(test)]
533    pub fn project(&self, project_id: u64) -> Option<&Project> {
534        self.projects.get(&project_id)
535    }
536
537    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
538        let project = self.projects.get(&project_id)?;
539        if project.host_connection_id == connection_id
540            || project.share.as_ref()?.guests.contains_key(&connection_id)
541        {
542            Some(project)
543        } else {
544            None
545        }
546    }
547
548    fn write_project(
549        &mut self,
550        project_id: u64,
551        connection_id: ConnectionId,
552    ) -> Option<&mut Project> {
553        let project = self.projects.get_mut(&project_id)?;
554        if project.host_connection_id == connection_id
555            || project.share.as_ref()?.guests.contains_key(&connection_id)
556        {
557            Some(project)
558        } else {
559            None
560        }
561    }
562
563    #[cfg(test)]
564    fn check_invariants(&self) {
565        for (connection_id, connection) in &self.connections {
566            for project_id in &connection.projects {
567                let project = &self.projects.get(&project_id).unwrap();
568                if project.host_connection_id != *connection_id {
569                    assert!(project
570                        .share
571                        .as_ref()
572                        .unwrap()
573                        .guests
574                        .contains_key(connection_id));
575                }
576            }
577            for channel_id in &connection.channels {
578                let channel = self.channels.get(channel_id).unwrap();
579                assert!(channel.connection_ids.contains(connection_id));
580            }
581            assert!(self
582                .connections_by_user_id
583                .get(&connection.user_id)
584                .unwrap()
585                .contains(connection_id));
586        }
587
588        for (user_id, connection_ids) in &self.connections_by_user_id {
589            for connection_id in connection_ids {
590                assert_eq!(
591                    self.connections.get(connection_id).unwrap().user_id,
592                    *user_id
593                );
594            }
595        }
596
597        for (project_id, project) in &self.projects {
598            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
599            assert!(host_connection.projects.contains(project_id));
600
601            for authorized_user_ids in project.authorized_user_ids() {
602                let visible_project_ids = self
603                    .visible_projects_by_user_id
604                    .get(&authorized_user_ids)
605                    .unwrap();
606                assert!(visible_project_ids.contains(project_id));
607            }
608
609            if let Some(share) = &project.share {
610                for guest_connection_id in share.guests.keys() {
611                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
612                    assert!(guest_connection.projects.contains(project_id));
613                }
614                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
615                assert_eq!(
616                    share.active_replica_ids,
617                    share
618                        .guests
619                        .values()
620                        .map(|(replica_id, _)| *replica_id)
621                        .collect::<HashSet<_>>(),
622                );
623            }
624        }
625
626        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
627            for project_id in visible_project_ids {
628                let project = self.projects.get(project_id).unwrap();
629                assert!(project.authorized_user_ids().contains(user_id));
630            }
631        }
632
633        for (channel_id, channel) in &self.channels {
634            for connection_id in &channel.connection_ids {
635                let connection = self.connections.get(connection_id).unwrap();
636                assert!(connection.channels.contains(channel_id));
637            }
638        }
639    }
640}
641
642impl Project {
643    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
644        self.worktrees
645            .values()
646            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
647    }
648
649    pub fn authorized_user_ids(&self) -> Vec<UserId> {
650        let mut ids = self
651            .worktrees
652            .values()
653            .flat_map(|worktree| worktree.authorized_user_ids.iter())
654            .copied()
655            .collect::<Vec<_>>();
656        ids.sort_unstable();
657        ids.dedup();
658        ids
659    }
660
661    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
662        if let Some(share) = &self.share {
663            share.guests.keys().copied().collect()
664        } else {
665            Vec::new()
666        }
667    }
668
669    pub fn connection_ids(&self) -> Vec<ConnectionId> {
670        if let Some(share) = &self.share {
671            share
672                .guests
673                .keys()
674                .copied()
675                .chain(Some(self.host_connection_id))
676                .collect()
677        } else {
678            vec![self.host_connection_id]
679        }
680    }
681
682    pub fn share(&self) -> tide::Result<&ProjectShare> {
683        Ok(self
684            .share
685            .as_ref()
686            .ok_or_else(|| anyhow!("worktree is not shared"))?)
687    }
688
689    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
690        Ok(self
691            .share
692            .as_mut()
693            .ok_or_else(|| anyhow!("worktree is not shared"))?)
694    }
695}
696
697impl Channel {
698    fn connection_ids(&self) -> Vec<ConnectionId> {
699        self.connection_ids.iter().copied().collect()
700    }
701}