store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40}
 41
 42pub struct WorktreeShare {
 43    pub entries: HashMap<u64, proto::Entry>,
 44    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 45}
 46
 47#[derive(Default)]
 48pub struct Channel {
 49    pub connection_ids: HashSet<ConnectionId>,
 50}
 51
 52pub type ReplicaId = u16;
 53
 54#[derive(Default)]
 55pub struct RemovedConnectionState {
 56    pub hosted_projects: HashMap<u64, Project>,
 57    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 58    pub contact_ids: HashSet<UserId>,
 59}
 60
 61pub struct JoinedProject<'a> {
 62    pub replica_id: ReplicaId,
 63    pub project: &'a Project,
 64}
 65
 66pub struct UnsharedWorktree {
 67    pub connection_ids: Vec<ConnectionId>,
 68    pub authorized_user_ids: Vec<UserId>,
 69}
 70
 71pub struct LeftProject {
 72    pub connection_ids: Vec<ConnectionId>,
 73    pub authorized_user_ids: Vec<UserId>,
 74}
 75
 76impl Store {
 77    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 78        self.connections.insert(
 79            connection_id,
 80            ConnectionState {
 81                user_id,
 82                projects: Default::default(),
 83                channels: Default::default(),
 84            },
 85        );
 86        self.connections_by_user_id
 87            .entry(user_id)
 88            .or_default()
 89            .insert(connection_id);
 90    }
 91
 92    pub fn remove_connection(
 93        &mut self,
 94        connection_id: ConnectionId,
 95    ) -> tide::Result<RemovedConnectionState> {
 96        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 97            connection
 98        } else {
 99            return Err(anyhow!("no such connection"))?;
100        };
101
102        for channel_id in &connection.channels {
103            if let Some(channel) = self.channels.get_mut(&channel_id) {
104                channel.connection_ids.remove(&connection_id);
105            }
106        }
107
108        let user_connections = self
109            .connections_by_user_id
110            .get_mut(&connection.user_id)
111            .unwrap();
112        user_connections.remove(&connection_id);
113        if user_connections.is_empty() {
114            self.connections_by_user_id.remove(&connection.user_id);
115        }
116
117        let mut result = RemovedConnectionState::default();
118        for project_id in connection.projects.clone() {
119            if let Some(project) = self.unregister_project(project_id, connection_id) {
120                result.contact_ids.extend(project.authorized_user_ids());
121                result.hosted_projects.insert(project_id, project);
122            } else if let Some(project) = self.leave_project(connection_id, project_id) {
123                result
124                    .guest_project_ids
125                    .insert(project_id, project.connection_ids);
126                result.contact_ids.extend(project.authorized_user_ids);
127            }
128        }
129
130        #[cfg(test)]
131        self.check_invariants();
132
133        Ok(result)
134    }
135
136    #[cfg(test)]
137    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
138        self.channels.get(&id)
139    }
140
141    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
142        if let Some(connection) = self.connections.get_mut(&connection_id) {
143            connection.channels.insert(channel_id);
144            self.channels
145                .entry(channel_id)
146                .or_default()
147                .connection_ids
148                .insert(connection_id);
149        }
150    }
151
152    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
153        if let Some(connection) = self.connections.get_mut(&connection_id) {
154            connection.channels.remove(&channel_id);
155            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
156                entry.get_mut().connection_ids.remove(&connection_id);
157                if entry.get_mut().connection_ids.is_empty() {
158                    entry.remove();
159                }
160            }
161        }
162    }
163
164    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
165        Ok(self
166            .connections
167            .get(&connection_id)
168            .ok_or_else(|| anyhow!("unknown connection"))?
169            .user_id)
170    }
171
172    pub fn connection_ids_for_user<'a>(
173        &'a self,
174        user_id: UserId,
175    ) -> impl 'a + Iterator<Item = ConnectionId> {
176        self.connections_by_user_id
177            .get(&user_id)
178            .into_iter()
179            .flatten()
180            .copied()
181    }
182
183    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
184        let mut contacts = HashMap::default();
185        for project_id in self
186            .visible_projects_by_user_id
187            .get(&user_id)
188            .unwrap_or(&HashSet::default())
189        {
190            let project = &self.projects[project_id];
191
192            let mut guests = HashSet::default();
193            if let Ok(share) = project.share() {
194                for guest_connection_id in share.guests.keys() {
195                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
196                        guests.insert(user_id.to_proto());
197                    }
198                }
199            }
200
201            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
202                let mut worktree_root_names = project
203                    .worktrees
204                    .values()
205                    .map(|worktree| worktree.root_name.clone())
206                    .collect::<Vec<_>>();
207                worktree_root_names.sort_unstable();
208                contacts
209                    .entry(host_user_id)
210                    .or_insert_with(|| proto::Contact {
211                        user_id: host_user_id.to_proto(),
212                        projects: Vec::new(),
213                    })
214                    .projects
215                    .push(proto::ProjectMetadata {
216                        id: *project_id,
217                        worktree_root_names,
218                        is_shared: project.share.is_some(),
219                        guests: guests.into_iter().collect(),
220                    });
221            }
222        }
223
224        contacts.into_values().collect()
225    }
226
227    pub fn register_project(
228        &mut self,
229        host_connection_id: ConnectionId,
230        host_user_id: UserId,
231    ) -> u64 {
232        let project_id = self.next_project_id;
233        self.projects.insert(
234            project_id,
235            Project {
236                host_connection_id,
237                host_user_id,
238                share: None,
239                worktrees: Default::default(),
240            },
241        );
242        self.next_project_id += 1;
243        project_id
244    }
245
246    pub fn register_worktree(
247        &mut self,
248        project_id: u64,
249        worktree_id: u64,
250        worktree: Worktree,
251    ) -> bool {
252        if let Some(project) = self.projects.get_mut(&project_id) {
253            for authorized_user_id in &worktree.authorized_user_ids {
254                self.visible_projects_by_user_id
255                    .entry(*authorized_user_id)
256                    .or_default()
257                    .insert(project_id);
258            }
259            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
260                connection.projects.insert(project_id);
261            }
262            project.worktrees.insert(worktree_id, worktree);
263
264            #[cfg(test)]
265            self.check_invariants();
266            true
267        } else {
268            false
269        }
270    }
271
272    pub fn unregister_project(
273        &mut self,
274        project_id: u64,
275        connection_id: ConnectionId,
276    ) -> Option<Project> {
277        match self.projects.entry(project_id) {
278            hash_map::Entry::Occupied(e) => {
279                if e.get().host_connection_id == connection_id {
280                    for user_id in e.get().authorized_user_ids() {
281                        if let hash_map::Entry::Occupied(mut projects) =
282                            self.visible_projects_by_user_id.entry(user_id)
283                        {
284                            projects.get_mut().remove(&project_id);
285                        }
286                    }
287
288                    Some(e.remove())
289                } else {
290                    None
291                }
292            }
293            hash_map::Entry::Vacant(_) => None,
294        }
295    }
296
297    pub fn unregister_worktree(
298        &mut self,
299        project_id: u64,
300        worktree_id: u64,
301        acting_connection_id: ConnectionId,
302    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
303        let project = self
304            .projects
305            .get_mut(&project_id)
306            .ok_or_else(|| anyhow!("no such project"))?;
307        if project.host_connection_id != acting_connection_id {
308            Err(anyhow!("not your worktree"))?;
309        }
310
311        let worktree = project
312            .worktrees
313            .remove(&worktree_id)
314            .ok_or_else(|| anyhow!("no such worktree"))?;
315
316        let mut guest_connection_ids = Vec::new();
317        if let Some(share) = &project.share {
318            guest_connection_ids.extend(share.guests.keys());
319        }
320
321        for authorized_user_id in &worktree.authorized_user_ids {
322            if let Some(visible_projects) =
323                self.visible_projects_by_user_id.get_mut(authorized_user_id)
324            {
325                if !project.has_authorized_user_id(*authorized_user_id) {
326                    visible_projects.remove(&project_id);
327                }
328            }
329        }
330
331        #[cfg(test)]
332        self.check_invariants();
333
334        Ok((worktree, guest_connection_ids))
335    }
336
337    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
338        if let Some(project) = self.projects.get_mut(&project_id) {
339            if project.host_connection_id == connection_id {
340                project.share = Some(ProjectShare::default());
341                return true;
342            }
343        }
344        false
345    }
346
347    pub fn unshare_project(
348        &mut self,
349        project_id: u64,
350        acting_connection_id: ConnectionId,
351    ) -> tide::Result<UnsharedWorktree> {
352        let project = if let Some(project) = self.projects.get_mut(&project_id) {
353            project
354        } else {
355            return Err(anyhow!("no such project"))?;
356        };
357
358        if project.host_connection_id != acting_connection_id {
359            return Err(anyhow!("not your project"))?;
360        }
361
362        let connection_ids = project.connection_ids();
363        let authorized_user_ids = project.authorized_user_ids();
364        if let Some(share) = project.share.take() {
365            for connection_id in share.guests.into_keys() {
366                if let Some(connection) = self.connections.get_mut(&connection_id) {
367                    connection.projects.remove(&project_id);
368                }
369            }
370
371            #[cfg(test)]
372            self.check_invariants();
373
374            Ok(UnsharedWorktree {
375                connection_ids,
376                authorized_user_ids,
377            })
378        } else {
379            Err(anyhow!("project is not shared"))?
380        }
381    }
382
383    pub fn share_worktree(
384        &mut self,
385        project_id: u64,
386        worktree_id: u64,
387        connection_id: ConnectionId,
388        entries: HashMap<u64, proto::Entry>,
389        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
390    ) -> Option<Vec<UserId>> {
391        let project = self.projects.get_mut(&project_id)?;
392        let worktree = project.worktrees.get_mut(&worktree_id)?;
393        if project.host_connection_id == connection_id && project.share.is_some() {
394            worktree.share = Some(WorktreeShare {
395                entries,
396                diagnostic_summaries,
397            });
398            Some(project.authorized_user_ids())
399        } else {
400            None
401        }
402    }
403
404    pub fn update_diagnostic_summary(
405        &mut self,
406        project_id: u64,
407        worktree_id: u64,
408        connection_id: ConnectionId,
409        summary: proto::DiagnosticSummary,
410    ) -> Option<Vec<ConnectionId>> {
411        let project = self.projects.get_mut(&project_id)?;
412        let worktree = project.worktrees.get_mut(&worktree_id)?;
413        if project.host_connection_id == connection_id {
414            if let Some(share) = worktree.share.as_mut() {
415                share
416                    .diagnostic_summaries
417                    .insert(summary.path.clone().into(), summary);
418                return Some(project.connection_ids());
419            }
420        }
421
422        None
423    }
424
425    pub fn join_project(
426        &mut self,
427        connection_id: ConnectionId,
428        user_id: UserId,
429        project_id: u64,
430    ) -> tide::Result<JoinedProject> {
431        let connection = self
432            .connections
433            .get_mut(&connection_id)
434            .ok_or_else(|| anyhow!("no such connection"))?;
435        let project = self
436            .projects
437            .get_mut(&project_id)
438            .and_then(|project| {
439                if project.has_authorized_user_id(user_id) {
440                    Some(project)
441                } else {
442                    None
443                }
444            })
445            .ok_or_else(|| anyhow!("no such project"))?;
446
447        let share = project.share_mut()?;
448        connection.projects.insert(project_id);
449
450        let mut replica_id = 1;
451        while share.active_replica_ids.contains(&replica_id) {
452            replica_id += 1;
453        }
454        share.active_replica_ids.insert(replica_id);
455        share.guests.insert(connection_id, (replica_id, user_id));
456
457        #[cfg(test)]
458        self.check_invariants();
459
460        Ok(JoinedProject {
461            replica_id,
462            project: &self.projects[&project_id],
463        })
464    }
465
466    pub fn leave_project(
467        &mut self,
468        connection_id: ConnectionId,
469        project_id: u64,
470    ) -> Option<LeftProject> {
471        let project = self.projects.get_mut(&project_id)?;
472        let share = project.share.as_mut()?;
473        let (replica_id, _) = share.guests.remove(&connection_id)?;
474        share.active_replica_ids.remove(&replica_id);
475
476        if let Some(connection) = self.connections.get_mut(&connection_id) {
477            connection.projects.remove(&project_id);
478        }
479
480        let connection_ids = project.connection_ids();
481        let authorized_user_ids = project.authorized_user_ids();
482
483        #[cfg(test)]
484        self.check_invariants();
485
486        Some(LeftProject {
487            connection_ids,
488            authorized_user_ids,
489        })
490    }
491
492    pub fn update_worktree(
493        &mut self,
494        connection_id: ConnectionId,
495        project_id: u64,
496        worktree_id: u64,
497        removed_entries: &[u64],
498        updated_entries: &[proto::Entry],
499    ) -> Option<Vec<ConnectionId>> {
500        let project = self.write_project(project_id, connection_id)?;
501        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
502        for entry_id in removed_entries {
503            share.entries.remove(&entry_id);
504        }
505        for entry in updated_entries {
506            share.entries.insert(entry.id, entry.clone());
507        }
508        Some(project.connection_ids())
509    }
510
511    pub fn project_connection_ids(
512        &self,
513        project_id: u64,
514        acting_connection_id: ConnectionId,
515    ) -> Option<Vec<ConnectionId>> {
516        Some(
517            self.read_project(project_id, acting_connection_id)?
518                .connection_ids(),
519        )
520    }
521
522    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
523        Some(self.channels.get(&channel_id)?.connection_ids())
524    }
525
526    #[cfg(test)]
527    pub fn project(&self, project_id: u64) -> Option<&Project> {
528        self.projects.get(&project_id)
529    }
530
531    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
532        let project = self.projects.get(&project_id)?;
533        if project.host_connection_id == connection_id
534            || project.share.as_ref()?.guests.contains_key(&connection_id)
535        {
536            Some(project)
537        } else {
538            None
539        }
540    }
541
542    fn write_project(
543        &mut self,
544        project_id: u64,
545        connection_id: ConnectionId,
546    ) -> Option<&mut Project> {
547        let project = self.projects.get_mut(&project_id)?;
548        if project.host_connection_id == connection_id
549            || project.share.as_ref()?.guests.contains_key(&connection_id)
550        {
551            Some(project)
552        } else {
553            None
554        }
555    }
556
557    #[cfg(test)]
558    fn check_invariants(&self) {
559        for (connection_id, connection) in &self.connections {
560            for project_id in &connection.projects {
561                let project = &self.projects.get(&project_id).unwrap();
562                if project.host_connection_id != *connection_id {
563                    assert!(project
564                        .share
565                        .as_ref()
566                        .unwrap()
567                        .guests
568                        .contains_key(connection_id));
569                }
570            }
571            for channel_id in &connection.channels {
572                let channel = self.channels.get(channel_id).unwrap();
573                assert!(channel.connection_ids.contains(connection_id));
574            }
575            assert!(self
576                .connections_by_user_id
577                .get(&connection.user_id)
578                .unwrap()
579                .contains(connection_id));
580        }
581
582        for (user_id, connection_ids) in &self.connections_by_user_id {
583            for connection_id in connection_ids {
584                assert_eq!(
585                    self.connections.get(connection_id).unwrap().user_id,
586                    *user_id
587                );
588            }
589        }
590
591        for (project_id, project) in &self.projects {
592            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
593            assert!(host_connection.projects.contains(project_id));
594
595            for authorized_user_ids in project.authorized_user_ids() {
596                let visible_project_ids = self
597                    .visible_projects_by_user_id
598                    .get(&authorized_user_ids)
599                    .unwrap();
600                assert!(visible_project_ids.contains(project_id));
601            }
602
603            if let Some(share) = &project.share {
604                for guest_connection_id in share.guests.keys() {
605                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
606                    assert!(guest_connection.projects.contains(project_id));
607                }
608                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
609                assert_eq!(
610                    share.active_replica_ids,
611                    share
612                        .guests
613                        .values()
614                        .map(|(replica_id, _)| *replica_id)
615                        .collect::<HashSet<_>>(),
616                );
617            }
618        }
619
620        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
621            for project_id in visible_project_ids {
622                let project = self.projects.get(project_id).unwrap();
623                assert!(project.authorized_user_ids().contains(user_id));
624            }
625        }
626
627        for (channel_id, channel) in &self.channels {
628            for connection_id in &channel.connection_ids {
629                let connection = self.connections.get(connection_id).unwrap();
630                assert!(connection.channels.contains(channel_id));
631            }
632        }
633    }
634}
635
636impl Project {
637    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
638        self.worktrees
639            .values()
640            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
641    }
642
643    pub fn authorized_user_ids(&self) -> Vec<UserId> {
644        let mut ids = self
645            .worktrees
646            .values()
647            .flat_map(|worktree| worktree.authorized_user_ids.iter())
648            .copied()
649            .collect::<Vec<_>>();
650        ids.sort_unstable();
651        ids.dedup();
652        ids
653    }
654
655    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
656        if let Some(share) = &self.share {
657            share.guests.keys().copied().collect()
658        } else {
659            Vec::new()
660        }
661    }
662
663    pub fn connection_ids(&self) -> Vec<ConnectionId> {
664        if let Some(share) = &self.share {
665            share
666                .guests
667                .keys()
668                .copied()
669                .chain(Some(self.host_connection_id))
670                .collect()
671        } else {
672            vec![self.host_connection_id]
673        }
674    }
675
676    pub fn share(&self) -> tide::Result<&ProjectShare> {
677        Ok(self
678            .share
679            .as_ref()
680            .ok_or_else(|| anyhow!("worktree is not shared"))?)
681    }
682
683    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
684        Ok(self
685            .share
686            .as_mut()
687            .ok_or_else(|| anyhow!("worktree is not shared"))?)
688    }
689}
690
691impl Channel {
692    fn connection_ids(&self) -> Vec<ConnectionId> {
693        self.connection_ids.iter().copied().collect()
694    }
695}