store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40}
 41
 42pub struct WorktreeShare {
 43    pub entries: HashMap<u64, proto::Entry>,
 44    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 45}
 46
 47#[derive(Default)]
 48pub struct Channel {
 49    pub connection_ids: HashSet<ConnectionId>,
 50}
 51
 52pub type ReplicaId = u16;
 53
 54#[derive(Default)]
 55pub struct RemovedConnectionState {
 56    pub hosted_projects: HashMap<u64, Project>,
 57    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 58    pub contact_ids: HashSet<UserId>,
 59}
 60
 61pub struct JoinedProject<'a> {
 62    pub replica_id: ReplicaId,
 63    pub project: &'a Project,
 64}
 65
 66pub struct UnsharedProject {
 67    pub connection_ids: Vec<ConnectionId>,
 68    pub authorized_user_ids: Vec<UserId>,
 69}
 70
 71pub struct LeftProject {
 72    pub connection_ids: Vec<ConnectionId>,
 73    pub authorized_user_ids: Vec<UserId>,
 74}
 75
 76impl Store {
 77    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 78        self.connections.insert(
 79            connection_id,
 80            ConnectionState {
 81                user_id,
 82                projects: Default::default(),
 83                channels: Default::default(),
 84            },
 85        );
 86        self.connections_by_user_id
 87            .entry(user_id)
 88            .or_default()
 89            .insert(connection_id);
 90    }
 91
 92    pub fn remove_connection(
 93        &mut self,
 94        connection_id: ConnectionId,
 95    ) -> tide::Result<RemovedConnectionState> {
 96        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 97            connection
 98        } else {
 99            return Err(anyhow!("no such connection"))?;
100        };
101
102        for channel_id in &connection.channels {
103            if let Some(channel) = self.channels.get_mut(&channel_id) {
104                channel.connection_ids.remove(&connection_id);
105            }
106        }
107
108        let user_connections = self
109            .connections_by_user_id
110            .get_mut(&connection.user_id)
111            .unwrap();
112        user_connections.remove(&connection_id);
113        if user_connections.is_empty() {
114            self.connections_by_user_id.remove(&connection.user_id);
115        }
116
117        let mut result = RemovedConnectionState::default();
118        for project_id in connection.projects.clone() {
119            if let Some(project) = self.unregister_project(project_id, connection_id) {
120                result.contact_ids.extend(project.authorized_user_ids());
121                result.hosted_projects.insert(project_id, project);
122            } else if let Some(project) = self.leave_project(connection_id, project_id) {
123                result
124                    .guest_project_ids
125                    .insert(project_id, project.connection_ids);
126                result.contact_ids.extend(project.authorized_user_ids);
127            }
128        }
129
130        #[cfg(test)]
131        self.check_invariants();
132
133        Ok(result)
134    }
135
136    #[cfg(test)]
137    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
138        self.channels.get(&id)
139    }
140
141    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
142        if let Some(connection) = self.connections.get_mut(&connection_id) {
143            connection.channels.insert(channel_id);
144            self.channels
145                .entry(channel_id)
146                .or_default()
147                .connection_ids
148                .insert(connection_id);
149        }
150    }
151
152    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
153        if let Some(connection) = self.connections.get_mut(&connection_id) {
154            connection.channels.remove(&channel_id);
155            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
156                entry.get_mut().connection_ids.remove(&connection_id);
157                if entry.get_mut().connection_ids.is_empty() {
158                    entry.remove();
159                }
160            }
161        }
162    }
163
164    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
165        Ok(self
166            .connections
167            .get(&connection_id)
168            .ok_or_else(|| anyhow!("unknown connection"))?
169            .user_id)
170    }
171
172    pub fn connection_ids_for_user<'a>(
173        &'a self,
174        user_id: UserId,
175    ) -> impl 'a + Iterator<Item = ConnectionId> {
176        self.connections_by_user_id
177            .get(&user_id)
178            .into_iter()
179            .flatten()
180            .copied()
181    }
182
183    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
184        let mut contacts = HashMap::default();
185        for project_id in self
186            .visible_projects_by_user_id
187            .get(&user_id)
188            .unwrap_or(&HashSet::default())
189        {
190            let project = &self.projects[project_id];
191
192            let mut guests = HashSet::default();
193            if let Ok(share) = project.share() {
194                for guest_connection_id in share.guests.keys() {
195                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
196                        guests.insert(user_id.to_proto());
197                    }
198                }
199            }
200
201            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
202                let mut worktree_root_names = project
203                    .worktrees
204                    .values()
205                    .map(|worktree| worktree.root_name.clone())
206                    .collect::<Vec<_>>();
207                worktree_root_names.sort_unstable();
208                contacts
209                    .entry(host_user_id)
210                    .or_insert_with(|| proto::Contact {
211                        user_id: host_user_id.to_proto(),
212                        projects: Vec::new(),
213                    })
214                    .projects
215                    .push(proto::ProjectMetadata {
216                        id: *project_id,
217                        worktree_root_names,
218                        is_shared: project.share.is_some(),
219                        guests: guests.into_iter().collect(),
220                    });
221            }
222        }
223
224        contacts.into_values().collect()
225    }
226
227    pub fn register_project(
228        &mut self,
229        host_connection_id: ConnectionId,
230        host_user_id: UserId,
231    ) -> u64 {
232        let project_id = self.next_project_id;
233        self.projects.insert(
234            project_id,
235            Project {
236                host_connection_id,
237                host_user_id,
238                share: None,
239                worktrees: Default::default(),
240            },
241        );
242        self.next_project_id += 1;
243        project_id
244    }
245
246    pub fn register_worktree(
247        &mut self,
248        project_id: u64,
249        worktree_id: u64,
250        worktree: Worktree,
251    ) -> bool {
252        if let Some(project) = self.projects.get_mut(&project_id) {
253            for authorized_user_id in &worktree.authorized_user_ids {
254                self.visible_projects_by_user_id
255                    .entry(*authorized_user_id)
256                    .or_default()
257                    .insert(project_id);
258            }
259            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
260                connection.projects.insert(project_id);
261            }
262            project.worktrees.insert(worktree_id, worktree);
263
264            #[cfg(test)]
265            self.check_invariants();
266            true
267        } else {
268            false
269        }
270    }
271
272    pub fn unregister_project(
273        &mut self,
274        project_id: u64,
275        connection_id: ConnectionId,
276    ) -> Option<Project> {
277        match self.projects.entry(project_id) {
278            hash_map::Entry::Occupied(e) => {
279                if e.get().host_connection_id == connection_id {
280                    for user_id in e.get().authorized_user_ids() {
281                        if let hash_map::Entry::Occupied(mut projects) =
282                            self.visible_projects_by_user_id.entry(user_id)
283                        {
284                            projects.get_mut().remove(&project_id);
285                        }
286                    }
287
288                    Some(e.remove())
289                } else {
290                    None
291                }
292            }
293            hash_map::Entry::Vacant(_) => None,
294        }
295    }
296
297    pub fn unregister_worktree(
298        &mut self,
299        project_id: u64,
300        worktree_id: u64,
301        acting_connection_id: ConnectionId,
302    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
303        let project = self
304            .projects
305            .get_mut(&project_id)
306            .ok_or_else(|| anyhow!("no such project"))?;
307        if project.host_connection_id != acting_connection_id {
308            Err(anyhow!("not your worktree"))?;
309        }
310
311        let worktree = project
312            .worktrees
313            .remove(&worktree_id)
314            .ok_or_else(|| anyhow!("no such worktree"))?;
315
316        let mut guest_connection_ids = Vec::new();
317        if let Some(share) = &project.share {
318            guest_connection_ids.extend(share.guests.keys());
319        }
320
321        for authorized_user_id in &worktree.authorized_user_ids {
322            if let Some(visible_projects) =
323                self.visible_projects_by_user_id.get_mut(authorized_user_id)
324            {
325                if !project.has_authorized_user_id(*authorized_user_id) {
326                    visible_projects.remove(&project_id);
327                }
328            }
329        }
330
331        #[cfg(test)]
332        self.check_invariants();
333
334        Ok((worktree, guest_connection_ids))
335    }
336
337    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
338        if let Some(project) = self.projects.get_mut(&project_id) {
339            if project.host_connection_id == connection_id {
340                project.share = Some(ProjectShare::default());
341                return true;
342            }
343        }
344        false
345    }
346
347    pub fn unshare_project(
348        &mut self,
349        project_id: u64,
350        acting_connection_id: ConnectionId,
351    ) -> tide::Result<UnsharedProject> {
352        let project = if let Some(project) = self.projects.get_mut(&project_id) {
353            project
354        } else {
355            return Err(anyhow!("no such project"))?;
356        };
357
358        if project.host_connection_id != acting_connection_id {
359            return Err(anyhow!("not your project"))?;
360        }
361
362        let connection_ids = project.connection_ids();
363        let authorized_user_ids = project.authorized_user_ids();
364        if let Some(share) = project.share.take() {
365            for connection_id in share.guests.into_keys() {
366                if let Some(connection) = self.connections.get_mut(&connection_id) {
367                    connection.projects.remove(&project_id);
368                }
369            }
370
371            for worktree in project.worktrees.values_mut() {
372                worktree.share.take();
373            }
374
375            #[cfg(test)]
376            self.check_invariants();
377
378            Ok(UnsharedProject {
379                connection_ids,
380                authorized_user_ids,
381            })
382        } else {
383            Err(anyhow!("project is not shared"))?
384        }
385    }
386
387    pub fn share_worktree(
388        &mut self,
389        project_id: u64,
390        worktree_id: u64,
391        connection_id: ConnectionId,
392        entries: HashMap<u64, proto::Entry>,
393        diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
394    ) -> Option<Vec<UserId>> {
395        let project = self.projects.get_mut(&project_id)?;
396        let worktree = project.worktrees.get_mut(&worktree_id)?;
397        if project.host_connection_id == connection_id && project.share.is_some() {
398            worktree.share = Some(WorktreeShare {
399                entries,
400                diagnostic_summaries,
401            });
402            Some(project.authorized_user_ids())
403        } else {
404            None
405        }
406    }
407
408    pub fn update_diagnostic_summary(
409        &mut self,
410        project_id: u64,
411        worktree_id: u64,
412        connection_id: ConnectionId,
413        summary: proto::DiagnosticSummary,
414    ) -> Option<Vec<ConnectionId>> {
415        let project = self.projects.get_mut(&project_id)?;
416        let worktree = project.worktrees.get_mut(&worktree_id)?;
417        if project.host_connection_id == connection_id {
418            if let Some(share) = worktree.share.as_mut() {
419                share
420                    .diagnostic_summaries
421                    .insert(summary.path.clone().into(), summary);
422                return Some(project.connection_ids());
423            }
424        }
425
426        None
427    }
428
429    pub fn join_project(
430        &mut self,
431        connection_id: ConnectionId,
432        user_id: UserId,
433        project_id: u64,
434    ) -> tide::Result<JoinedProject> {
435        let connection = self
436            .connections
437            .get_mut(&connection_id)
438            .ok_or_else(|| anyhow!("no such connection"))?;
439        let project = self
440            .projects
441            .get_mut(&project_id)
442            .and_then(|project| {
443                if project.has_authorized_user_id(user_id) {
444                    Some(project)
445                } else {
446                    None
447                }
448            })
449            .ok_or_else(|| anyhow!("no such project"))?;
450
451        let share = project.share_mut()?;
452        connection.projects.insert(project_id);
453
454        let mut replica_id = 1;
455        while share.active_replica_ids.contains(&replica_id) {
456            replica_id += 1;
457        }
458        share.active_replica_ids.insert(replica_id);
459        share.guests.insert(connection_id, (replica_id, user_id));
460
461        #[cfg(test)]
462        self.check_invariants();
463
464        Ok(JoinedProject {
465            replica_id,
466            project: &self.projects[&project_id],
467        })
468    }
469
470    pub fn leave_project(
471        &mut self,
472        connection_id: ConnectionId,
473        project_id: u64,
474    ) -> Option<LeftProject> {
475        let project = self.projects.get_mut(&project_id)?;
476        let share = project.share.as_mut()?;
477        let (replica_id, _) = share.guests.remove(&connection_id)?;
478        share.active_replica_ids.remove(&replica_id);
479
480        if let Some(connection) = self.connections.get_mut(&connection_id) {
481            connection.projects.remove(&project_id);
482        }
483
484        let connection_ids = project.connection_ids();
485        let authorized_user_ids = project.authorized_user_ids();
486
487        #[cfg(test)]
488        self.check_invariants();
489
490        Some(LeftProject {
491            connection_ids,
492            authorized_user_ids,
493        })
494    }
495
496    pub fn update_worktree(
497        &mut self,
498        connection_id: ConnectionId,
499        project_id: u64,
500        worktree_id: u64,
501        removed_entries: &[u64],
502        updated_entries: &[proto::Entry],
503    ) -> Option<Vec<ConnectionId>> {
504        let project = self.write_project(project_id, connection_id)?;
505        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
506        for entry_id in removed_entries {
507            share.entries.remove(&entry_id);
508        }
509        for entry in updated_entries {
510            share.entries.insert(entry.id, entry.clone());
511        }
512        Some(project.connection_ids())
513    }
514
515    pub fn project_connection_ids(
516        &self,
517        project_id: u64,
518        acting_connection_id: ConnectionId,
519    ) -> Option<Vec<ConnectionId>> {
520        Some(
521            self.read_project(project_id, acting_connection_id)?
522                .connection_ids(),
523        )
524    }
525
526    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
527        Some(self.channels.get(&channel_id)?.connection_ids())
528    }
529
530    #[cfg(test)]
531    pub fn project(&self, project_id: u64) -> Option<&Project> {
532        self.projects.get(&project_id)
533    }
534
535    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
536        let project = self.projects.get(&project_id)?;
537        if project.host_connection_id == connection_id
538            || project.share.as_ref()?.guests.contains_key(&connection_id)
539        {
540            Some(project)
541        } else {
542            None
543        }
544    }
545
546    fn write_project(
547        &mut self,
548        project_id: u64,
549        connection_id: ConnectionId,
550    ) -> Option<&mut Project> {
551        let project = self.projects.get_mut(&project_id)?;
552        if project.host_connection_id == connection_id
553            || project.share.as_ref()?.guests.contains_key(&connection_id)
554        {
555            Some(project)
556        } else {
557            None
558        }
559    }
560
561    #[cfg(test)]
562    fn check_invariants(&self) {
563        for (connection_id, connection) in &self.connections {
564            for project_id in &connection.projects {
565                let project = &self.projects.get(&project_id).unwrap();
566                if project.host_connection_id != *connection_id {
567                    assert!(project
568                        .share
569                        .as_ref()
570                        .unwrap()
571                        .guests
572                        .contains_key(connection_id));
573                }
574            }
575            for channel_id in &connection.channels {
576                let channel = self.channels.get(channel_id).unwrap();
577                assert!(channel.connection_ids.contains(connection_id));
578            }
579            assert!(self
580                .connections_by_user_id
581                .get(&connection.user_id)
582                .unwrap()
583                .contains(connection_id));
584        }
585
586        for (user_id, connection_ids) in &self.connections_by_user_id {
587            for connection_id in connection_ids {
588                assert_eq!(
589                    self.connections.get(connection_id).unwrap().user_id,
590                    *user_id
591                );
592            }
593        }
594
595        for (project_id, project) in &self.projects {
596            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
597            assert!(host_connection.projects.contains(project_id));
598
599            for authorized_user_ids in project.authorized_user_ids() {
600                let visible_project_ids = self
601                    .visible_projects_by_user_id
602                    .get(&authorized_user_ids)
603                    .unwrap();
604                assert!(visible_project_ids.contains(project_id));
605            }
606
607            if let Some(share) = &project.share {
608                for guest_connection_id in share.guests.keys() {
609                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
610                    assert!(guest_connection.projects.contains(project_id));
611                }
612                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
613                assert_eq!(
614                    share.active_replica_ids,
615                    share
616                        .guests
617                        .values()
618                        .map(|(replica_id, _)| *replica_id)
619                        .collect::<HashSet<_>>(),
620                );
621            }
622        }
623
624        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
625            for project_id in visible_project_ids {
626                let project = self.projects.get(project_id).unwrap();
627                assert!(project.authorized_user_ids().contains(user_id));
628            }
629        }
630
631        for (channel_id, channel) in &self.channels {
632            for connection_id in &channel.connection_ids {
633                let connection = self.connections.get(connection_id).unwrap();
634                assert!(connection.channels.contains(channel_id));
635            }
636        }
637    }
638}
639
640impl Project {
641    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
642        self.worktrees
643            .values()
644            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
645    }
646
647    pub fn authorized_user_ids(&self) -> Vec<UserId> {
648        let mut ids = self
649            .worktrees
650            .values()
651            .flat_map(|worktree| worktree.authorized_user_ids.iter())
652            .copied()
653            .collect::<Vec<_>>();
654        ids.sort_unstable();
655        ids.dedup();
656        ids
657    }
658
659    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
660        if let Some(share) = &self.share {
661            share.guests.keys().copied().collect()
662        } else {
663            Vec::new()
664        }
665    }
666
667    pub fn connection_ids(&self) -> Vec<ConnectionId> {
668        if let Some(share) = &self.share {
669            share
670                .guests
671                .keys()
672                .copied()
673                .chain(Some(self.host_connection_id))
674                .collect()
675        } else {
676            vec![self.host_connection_id]
677        }
678    }
679
680    pub fn share(&self) -> tide::Result<&ProjectShare> {
681        Ok(self
682            .share
683            .as_ref()
684            .ok_or_else(|| anyhow!("worktree is not shared"))?)
685    }
686
687    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
688        Ok(self
689            .share
690            .as_mut()
691            .ok_or_else(|| anyhow!("worktree is not shared"))?)
692    }
693}
694
695impl Channel {
696    fn connection_ids(&self) -> Vec<ConnectionId> {
697        self.connection_ids.iter().copied().collect()
698    }
699}