store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub weak: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47}
 48
 49#[derive(Default)]
 50pub struct Channel {
 51    pub connection_ids: HashSet<ConnectionId>,
 52}
 53
 54pub type ReplicaId = u16;
 55
 56#[derive(Default)]
 57pub struct RemovedConnectionState {
 58    pub hosted_projects: HashMap<u64, Project>,
 59    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 60    pub contact_ids: HashSet<UserId>,
 61}
 62
 63pub struct JoinedProject<'a> {
 64    pub replica_id: ReplicaId,
 65    pub project: &'a Project,
 66}
 67
 68pub struct UnsharedProject {
 69    pub connection_ids: Vec<ConnectionId>,
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct LeftProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78impl Store {
 79    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 80        self.connections.insert(
 81            connection_id,
 82            ConnectionState {
 83                user_id,
 84                projects: Default::default(),
 85                channels: Default::default(),
 86            },
 87        );
 88        self.connections_by_user_id
 89            .entry(user_id)
 90            .or_default()
 91            .insert(connection_id);
 92    }
 93
 94    pub fn remove_connection(
 95        &mut self,
 96        connection_id: ConnectionId,
 97    ) -> tide::Result<RemovedConnectionState> {
 98        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 99            connection
100        } else {
101            return Err(anyhow!("no such connection"))?;
102        };
103
104        for channel_id in &connection.channels {
105            if let Some(channel) = self.channels.get_mut(&channel_id) {
106                channel.connection_ids.remove(&connection_id);
107            }
108        }
109
110        let user_connections = self
111            .connections_by_user_id
112            .get_mut(&connection.user_id)
113            .unwrap();
114        user_connections.remove(&connection_id);
115        if user_connections.is_empty() {
116            self.connections_by_user_id.remove(&connection.user_id);
117        }
118
119        let mut result = RemovedConnectionState::default();
120        for project_id in connection.projects.clone() {
121            if let Ok(project) = self.unregister_project(project_id, connection_id) {
122                result.contact_ids.extend(project.authorized_user_ids());
123                result.hosted_projects.insert(project_id, project);
124            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125                result
126                    .guest_project_ids
127                    .insert(project_id, project.connection_ids);
128                result.contact_ids.extend(project.authorized_user_ids);
129            }
130        }
131
132        #[cfg(test)]
133        self.check_invariants();
134
135        Ok(result)
136    }
137
138    #[cfg(test)]
139    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140        self.channels.get(&id)
141    }
142
143    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.insert(channel_id);
146            self.channels
147                .entry(channel_id)
148                .or_default()
149                .connection_ids
150                .insert(connection_id);
151        }
152    }
153
154    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155        if let Some(connection) = self.connections.get_mut(&connection_id) {
156            connection.channels.remove(&channel_id);
157            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158                entry.get_mut().connection_ids.remove(&connection_id);
159                if entry.get_mut().connection_ids.is_empty() {
160                    entry.remove();
161                }
162            }
163        }
164    }
165
166    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167        Ok(self
168            .connections
169            .get(&connection_id)
170            .ok_or_else(|| anyhow!("unknown connection"))?
171            .user_id)
172    }
173
174    pub fn connection_ids_for_user<'a>(
175        &'a self,
176        user_id: UserId,
177    ) -> impl 'a + Iterator<Item = ConnectionId> {
178        self.connections_by_user_id
179            .get(&user_id)
180            .into_iter()
181            .flatten()
182            .copied()
183    }
184
185    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186        let mut contacts = HashMap::default();
187        for project_id in self
188            .visible_projects_by_user_id
189            .get(&user_id)
190            .unwrap_or(&HashSet::default())
191        {
192            let project = &self.projects[project_id];
193
194            let mut guests = HashSet::default();
195            if let Ok(share) = project.share() {
196                for guest_connection_id in share.guests.keys() {
197                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198                        guests.insert(user_id.to_proto());
199                    }
200                }
201            }
202
203            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204                let mut worktree_root_names = project
205                    .worktrees
206                    .values()
207                    .filter(|worktree| !worktree.weak)
208                    .map(|worktree| worktree.root_name.clone())
209                    .collect::<Vec<_>>();
210                worktree_root_names.sort_unstable();
211                contacts
212                    .entry(host_user_id)
213                    .or_insert_with(|| proto::Contact {
214                        user_id: host_user_id.to_proto(),
215                        projects: Vec::new(),
216                    })
217                    .projects
218                    .push(proto::ProjectMetadata {
219                        id: *project_id,
220                        worktree_root_names,
221                        is_shared: project.share.is_some(),
222                        guests: guests.into_iter().collect(),
223                    });
224            }
225        }
226
227        contacts.into_values().collect()
228    }
229
230    pub fn register_project(
231        &mut self,
232        host_connection_id: ConnectionId,
233        host_user_id: UserId,
234    ) -> u64 {
235        let project_id = self.next_project_id;
236        self.projects.insert(
237            project_id,
238            Project {
239                host_connection_id,
240                host_user_id,
241                share: None,
242                worktrees: Default::default(),
243            },
244        );
245        self.next_project_id += 1;
246        project_id
247    }
248
249    pub fn register_worktree(
250        &mut self,
251        project_id: u64,
252        worktree_id: u64,
253        connection_id: ConnectionId,
254        worktree: Worktree,
255    ) -> tide::Result<()> {
256        let project = self
257            .projects
258            .get_mut(&project_id)
259            .ok_or_else(|| anyhow!("no such project"))?;
260        if project.host_connection_id == connection_id {
261            for authorized_user_id in &worktree.authorized_user_ids {
262                self.visible_projects_by_user_id
263                    .entry(*authorized_user_id)
264                    .or_default()
265                    .insert(project_id);
266            }
267            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268                connection.projects.insert(project_id);
269            }
270            project.worktrees.insert(worktree_id, worktree);
271            if let Ok(share) = project.share_mut() {
272                share.worktrees.insert(worktree_id, Default::default());
273            }
274
275            #[cfg(test)]
276            self.check_invariants();
277            Ok(())
278        } else {
279            Err(anyhow!("no such project"))?
280        }
281    }
282
283    pub fn unregister_project(
284        &mut self,
285        project_id: u64,
286        connection_id: ConnectionId,
287    ) -> tide::Result<Project> {
288        match self.projects.entry(project_id) {
289            hash_map::Entry::Occupied(e) => {
290                if e.get().host_connection_id == connection_id {
291                    for user_id in e.get().authorized_user_ids() {
292                        if let hash_map::Entry::Occupied(mut projects) =
293                            self.visible_projects_by_user_id.entry(user_id)
294                        {
295                            projects.get_mut().remove(&project_id);
296                        }
297                    }
298
299                    let project = e.remove();
300                    if let Some(share) = &project.share {
301                        for guest_connection in share.guests.keys() {
302                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
303                                connection.projects.remove(&project_id);
304                            }
305                        }
306                    }
307
308                    Ok(project)
309                } else {
310                    Err(anyhow!("no such project"))?
311                }
312            }
313            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
314        }
315    }
316
317    pub fn unregister_worktree(
318        &mut self,
319        project_id: u64,
320        worktree_id: u64,
321        acting_connection_id: ConnectionId,
322    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
323        let project = self
324            .projects
325            .get_mut(&project_id)
326            .ok_or_else(|| anyhow!("no such project"))?;
327        if project.host_connection_id != acting_connection_id {
328            Err(anyhow!("not your worktree"))?;
329        }
330
331        let worktree = project
332            .worktrees
333            .remove(&worktree_id)
334            .ok_or_else(|| anyhow!("no such worktree"))?;
335
336        let mut guest_connection_ids = Vec::new();
337        if let Ok(share) = project.share_mut() {
338            guest_connection_ids.extend(share.guests.keys());
339            share.worktrees.remove(&worktree_id);
340        }
341
342        for authorized_user_id in &worktree.authorized_user_ids {
343            if let Some(visible_projects) =
344                self.visible_projects_by_user_id.get_mut(authorized_user_id)
345            {
346                if !project.has_authorized_user_id(*authorized_user_id) {
347                    visible_projects.remove(&project_id);
348                }
349            }
350        }
351
352        #[cfg(test)]
353        self.check_invariants();
354
355        Ok((worktree, guest_connection_ids))
356    }
357
358    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
359        if let Some(project) = self.projects.get_mut(&project_id) {
360            if project.host_connection_id == connection_id {
361                let mut share = ProjectShare::default();
362                for worktree_id in project.worktrees.keys() {
363                    share.worktrees.insert(*worktree_id, Default::default());
364                }
365                project.share = Some(share);
366                return true;
367            }
368        }
369        false
370    }
371
372    pub fn unshare_project(
373        &mut self,
374        project_id: u64,
375        acting_connection_id: ConnectionId,
376    ) -> tide::Result<UnsharedProject> {
377        let project = if let Some(project) = self.projects.get_mut(&project_id) {
378            project
379        } else {
380            return Err(anyhow!("no such project"))?;
381        };
382
383        if project.host_connection_id != acting_connection_id {
384            return Err(anyhow!("not your project"))?;
385        }
386
387        let connection_ids = project.connection_ids();
388        let authorized_user_ids = project.authorized_user_ids();
389        if let Some(share) = project.share.take() {
390            for connection_id in share.guests.into_keys() {
391                if let Some(connection) = self.connections.get_mut(&connection_id) {
392                    connection.projects.remove(&project_id);
393                }
394            }
395
396            #[cfg(test)]
397            self.check_invariants();
398
399            Ok(UnsharedProject {
400                connection_ids,
401                authorized_user_ids,
402            })
403        } else {
404            Err(anyhow!("project is not shared"))?
405        }
406    }
407
408    pub fn update_diagnostic_summary(
409        &mut self,
410        project_id: u64,
411        worktree_id: u64,
412        connection_id: ConnectionId,
413        summary: proto::DiagnosticSummary,
414    ) -> tide::Result<Vec<ConnectionId>> {
415        let project = self
416            .projects
417            .get_mut(&project_id)
418            .ok_or_else(|| anyhow!("no such project"))?;
419        if project.host_connection_id == connection_id {
420            let worktree = project
421                .share_mut()?
422                .worktrees
423                .get_mut(&worktree_id)
424                .ok_or_else(|| anyhow!("no such worktree"))?;
425            worktree
426                .diagnostic_summaries
427                .insert(summary.path.clone().into(), summary);
428            return Ok(project.connection_ids());
429        }
430
431        Err(anyhow!("no such worktree"))?
432    }
433
434    pub fn join_project(
435        &mut self,
436        connection_id: ConnectionId,
437        user_id: UserId,
438        project_id: u64,
439    ) -> tide::Result<JoinedProject> {
440        let connection = self
441            .connections
442            .get_mut(&connection_id)
443            .ok_or_else(|| anyhow!("no such connection"))?;
444        let project = self
445            .projects
446            .get_mut(&project_id)
447            .and_then(|project| {
448                if project.has_authorized_user_id(user_id) {
449                    Some(project)
450                } else {
451                    None
452                }
453            })
454            .ok_or_else(|| anyhow!("no such project"))?;
455
456        let share = project.share_mut()?;
457        connection.projects.insert(project_id);
458
459        let mut replica_id = 1;
460        while share.active_replica_ids.contains(&replica_id) {
461            replica_id += 1;
462        }
463        share.active_replica_ids.insert(replica_id);
464        share.guests.insert(connection_id, (replica_id, user_id));
465
466        #[cfg(test)]
467        self.check_invariants();
468
469        Ok(JoinedProject {
470            replica_id,
471            project: &self.projects[&project_id],
472        })
473    }
474
475    pub fn leave_project(
476        &mut self,
477        connection_id: ConnectionId,
478        project_id: u64,
479    ) -> tide::Result<LeftProject> {
480        let project = self
481            .projects
482            .get_mut(&project_id)
483            .ok_or_else(|| anyhow!("no such project"))?;
484        let share = project
485            .share
486            .as_mut()
487            .ok_or_else(|| anyhow!("project is not shared"))?;
488        let (replica_id, _) = share
489            .guests
490            .remove(&connection_id)
491            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
492        share.active_replica_ids.remove(&replica_id);
493
494        if let Some(connection) = self.connections.get_mut(&connection_id) {
495            connection.projects.remove(&project_id);
496        }
497
498        let connection_ids = project.connection_ids();
499        let authorized_user_ids = project.authorized_user_ids();
500
501        #[cfg(test)]
502        self.check_invariants();
503
504        Ok(LeftProject {
505            connection_ids,
506            authorized_user_ids,
507        })
508    }
509
510    pub fn update_worktree(
511        &mut self,
512        connection_id: ConnectionId,
513        project_id: u64,
514        worktree_id: u64,
515        removed_entries: &[u64],
516        updated_entries: &[proto::Entry],
517    ) -> tide::Result<Vec<ConnectionId>> {
518        let project = self.write_project(project_id, connection_id)?;
519        let worktree = project
520            .share_mut()?
521            .worktrees
522            .get_mut(&worktree_id)
523            .ok_or_else(|| anyhow!("no such worktree"))?;
524        for entry_id in removed_entries {
525            worktree.entries.remove(&entry_id);
526        }
527        for entry in updated_entries {
528            worktree.entries.insert(entry.id, entry.clone());
529        }
530        Ok(project.connection_ids())
531    }
532
533    pub fn project_connection_ids(
534        &self,
535        project_id: u64,
536        acting_connection_id: ConnectionId,
537    ) -> tide::Result<Vec<ConnectionId>> {
538        Ok(self
539            .read_project(project_id, acting_connection_id)?
540            .connection_ids())
541    }
542
543    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
544        Ok(self
545            .channels
546            .get(&channel_id)
547            .ok_or_else(|| anyhow!("no such channel"))?
548            .connection_ids())
549    }
550
551    #[cfg(test)]
552    pub fn project(&self, project_id: u64) -> Option<&Project> {
553        self.projects.get(&project_id)
554    }
555
556    pub fn read_project(
557        &self,
558        project_id: u64,
559        connection_id: ConnectionId,
560    ) -> tide::Result<&Project> {
561        let project = self
562            .projects
563            .get(&project_id)
564            .ok_or_else(|| anyhow!("no such project"))?;
565        if project.host_connection_id == connection_id
566            || project
567                .share
568                .as_ref()
569                .ok_or_else(|| anyhow!("project is not shared"))?
570                .guests
571                .contains_key(&connection_id)
572        {
573            Ok(project)
574        } else {
575            Err(anyhow!("no such project"))?
576        }
577    }
578
579    fn write_project(
580        &mut self,
581        project_id: u64,
582        connection_id: ConnectionId,
583    ) -> tide::Result<&mut Project> {
584        let project = self
585            .projects
586            .get_mut(&project_id)
587            .ok_or_else(|| anyhow!("no such project"))?;
588        if project.host_connection_id == connection_id
589            || project
590                .share
591                .as_ref()
592                .ok_or_else(|| anyhow!("project is not shared"))?
593                .guests
594                .contains_key(&connection_id)
595        {
596            Ok(project)
597        } else {
598            Err(anyhow!("no such project"))?
599        }
600    }
601
602    #[cfg(test)]
603    fn check_invariants(&self) {
604        for (connection_id, connection) in &self.connections {
605            for project_id in &connection.projects {
606                let project = &self.projects.get(&project_id).unwrap();
607                if project.host_connection_id != *connection_id {
608                    assert!(project
609                        .share
610                        .as_ref()
611                        .unwrap()
612                        .guests
613                        .contains_key(connection_id));
614                }
615            }
616            for channel_id in &connection.channels {
617                let channel = self.channels.get(channel_id).unwrap();
618                assert!(channel.connection_ids.contains(connection_id));
619            }
620            assert!(self
621                .connections_by_user_id
622                .get(&connection.user_id)
623                .unwrap()
624                .contains(connection_id));
625        }
626
627        for (user_id, connection_ids) in &self.connections_by_user_id {
628            for connection_id in connection_ids {
629                assert_eq!(
630                    self.connections.get(connection_id).unwrap().user_id,
631                    *user_id
632                );
633            }
634        }
635
636        for (project_id, project) in &self.projects {
637            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
638            assert!(host_connection.projects.contains(project_id));
639
640            for authorized_user_ids in project.authorized_user_ids() {
641                let visible_project_ids = self
642                    .visible_projects_by_user_id
643                    .get(&authorized_user_ids)
644                    .unwrap();
645                assert!(visible_project_ids.contains(project_id));
646            }
647
648            if let Some(share) = &project.share {
649                for guest_connection_id in share.guests.keys() {
650                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
651                    assert!(guest_connection.projects.contains(project_id));
652                }
653                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
654                assert_eq!(
655                    share.active_replica_ids,
656                    share
657                        .guests
658                        .values()
659                        .map(|(replica_id, _)| *replica_id)
660                        .collect::<HashSet<_>>(),
661                );
662            }
663        }
664
665        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
666            for project_id in visible_project_ids {
667                let project = self.projects.get(project_id).unwrap();
668                assert!(project.authorized_user_ids().contains(user_id));
669            }
670        }
671
672        for (channel_id, channel) in &self.channels {
673            for connection_id in &channel.connection_ids {
674                let connection = self.connections.get(connection_id).unwrap();
675                assert!(connection.channels.contains(channel_id));
676            }
677        }
678    }
679}
680
681impl Project {
682    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
683        self.worktrees
684            .values()
685            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
686    }
687
688    pub fn authorized_user_ids(&self) -> Vec<UserId> {
689        let mut ids = self
690            .worktrees
691            .values()
692            .flat_map(|worktree| worktree.authorized_user_ids.iter())
693            .copied()
694            .collect::<Vec<_>>();
695        ids.sort_unstable();
696        ids.dedup();
697        ids
698    }
699
700    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
701        if let Some(share) = &self.share {
702            share.guests.keys().copied().collect()
703        } else {
704            Vec::new()
705        }
706    }
707
708    pub fn connection_ids(&self) -> Vec<ConnectionId> {
709        if let Some(share) = &self.share {
710            share
711                .guests
712                .keys()
713                .copied()
714                .chain(Some(self.host_connection_id))
715                .collect()
716        } else {
717            vec![self.host_connection_id]
718        }
719    }
720
721    pub fn share(&self) -> tide::Result<&ProjectShare> {
722        Ok(self
723            .share
724            .as_ref()
725            .ok_or_else(|| anyhow!("worktree is not shared"))?)
726    }
727
728    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
729        Ok(self
730            .share
731            .as_mut()
732            .ok_or_else(|| anyhow!("worktree is not shared"))?)
733    }
734}
735
736impl Channel {
737    fn connection_ids(&self) -> Vec<ConnectionId> {
738        self.connection_ids.iter().copied().collect()
739    }
740}