store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 14    channels: HashMap<ChannelId, Channel>,
 15    next_project_id: u64,
 16}
 17
 18struct ConnectionState {
 19    user_id: UserId,
 20    projects: HashSet<u64>,
 21    channels: HashSet<ChannelId>,
 22}
 23
 24pub struct Project {
 25    pub host_connection_id: ConnectionId,
 26    pub host_user_id: UserId,
 27    pub share: Option<ProjectShare>,
 28    pub worktrees: HashMap<u64, Worktree>,
 29    pub language_servers: Vec<proto::LanguageServer>,
 30}
 31
 32pub struct Worktree {
 33    pub authorized_user_ids: Vec<UserId>,
 34    pub root_name: String,
 35    pub visible: bool,
 36}
 37
 38#[derive(Default)]
 39pub struct ProjectShare {
 40    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 41    pub active_replica_ids: HashSet<ReplicaId>,
 42    pub worktrees: HashMap<u64, WorktreeShare>,
 43}
 44
 45#[derive(Default)]
 46pub struct WorktreeShare {
 47    pub entries: HashMap<u64, proto::Entry>,
 48    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 49    pub scan_id: u64,
 50}
 51
 52#[derive(Default)]
 53pub struct Channel {
 54    pub connection_ids: HashSet<ConnectionId>,
 55}
 56
 57pub type ReplicaId = u16;
 58
 59#[derive(Default)]
 60pub struct RemovedConnectionState {
 61    pub hosted_projects: HashMap<u64, Project>,
 62    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 63    pub contact_ids: HashSet<UserId>,
 64}
 65
 66pub struct JoinedProject<'a> {
 67    pub replica_id: ReplicaId,
 68    pub project: &'a Project,
 69}
 70
 71pub struct SharedProject {
 72    pub authorized_user_ids: Vec<UserId>,
 73}
 74
 75pub struct UnsharedProject {
 76    pub connection_ids: Vec<ConnectionId>,
 77    pub authorized_user_ids: Vec<UserId>,
 78}
 79
 80pub struct LeftProject {
 81    pub connection_ids: Vec<ConnectionId>,
 82    pub authorized_user_ids: Vec<UserId>,
 83}
 84
 85#[derive(Copy, Clone)]
 86pub struct Metrics {
 87    pub connections: usize,
 88    pub registered_projects: usize,
 89    pub shared_projects: usize,
 90}
 91
 92impl Store {
 93    pub fn metrics(&self) -> Metrics {
 94        let connections = self.connections.len();
 95        let mut registered_projects = 0;
 96        let mut shared_projects = 0;
 97        for project in self.projects.values() {
 98            registered_projects += 1;
 99            if project.share.is_some() {
100                shared_projects += 1;
101            }
102        }
103
104        Metrics {
105            connections,
106            registered_projects,
107            shared_projects,
108        }
109    }
110
111    #[instrument(skip(self))]
112    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
113        self.connections.insert(
114            connection_id,
115            ConnectionState {
116                user_id,
117                projects: Default::default(),
118                channels: Default::default(),
119            },
120        );
121        self.connections_by_user_id
122            .entry(user_id)
123            .or_default()
124            .insert(connection_id);
125    }
126
127    #[instrument(skip(self))]
128    pub fn remove_connection(
129        &mut self,
130        connection_id: ConnectionId,
131    ) -> Result<RemovedConnectionState> {
132        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
133            connection
134        } else {
135            return Err(anyhow!("no such connection"))?;
136        };
137
138        for channel_id in &connection.channels {
139            if let Some(channel) = self.channels.get_mut(&channel_id) {
140                channel.connection_ids.remove(&connection_id);
141            }
142        }
143
144        let user_connections = self
145            .connections_by_user_id
146            .get_mut(&connection.user_id)
147            .unwrap();
148        user_connections.remove(&connection_id);
149        if user_connections.is_empty() {
150            self.connections_by_user_id.remove(&connection.user_id);
151        }
152
153        let mut result = RemovedConnectionState::default();
154        for project_id in connection.projects.clone() {
155            if let Ok(project) = self.unregister_project(project_id, connection_id) {
156                result.contact_ids.extend(project.authorized_user_ids());
157                result.hosted_projects.insert(project_id, project);
158            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
159                result
160                    .guest_project_ids
161                    .insert(project_id, project.connection_ids);
162                result.contact_ids.extend(project.authorized_user_ids);
163            }
164        }
165
166        Ok(result)
167    }
168
169    #[cfg(test)]
170    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
171        self.channels.get(&id)
172    }
173
174    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
175        if let Some(connection) = self.connections.get_mut(&connection_id) {
176            connection.channels.insert(channel_id);
177            self.channels
178                .entry(channel_id)
179                .or_default()
180                .connection_ids
181                .insert(connection_id);
182        }
183    }
184
185    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
186        if let Some(connection) = self.connections.get_mut(&connection_id) {
187            connection.channels.remove(&channel_id);
188            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
189                entry.get_mut().connection_ids.remove(&connection_id);
190                if entry.get_mut().connection_ids.is_empty() {
191                    entry.remove();
192                }
193            }
194        }
195    }
196
197    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
198        Ok(self
199            .connections
200            .get(&connection_id)
201            .ok_or_else(|| anyhow!("unknown connection"))?
202            .user_id)
203    }
204
205    pub fn connection_ids_for_user<'a>(
206        &'a self,
207        user_id: UserId,
208    ) -> impl 'a + Iterator<Item = ConnectionId> {
209        self.connections_by_user_id
210            .get(&user_id)
211            .into_iter()
212            .flatten()
213            .copied()
214    }
215
216    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
217        let mut contacts = HashMap::default();
218        for project_id in self
219            .visible_projects_by_user_id
220            .get(&user_id)
221            .unwrap_or(&HashSet::default())
222        {
223            let project = &self.projects[project_id];
224
225            let mut guests = HashSet::default();
226            if let Ok(share) = project.share() {
227                for guest_connection_id in share.guests.keys() {
228                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
229                        guests.insert(user_id.to_proto());
230                    }
231                }
232            }
233
234            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
235                let mut worktree_root_names = project
236                    .worktrees
237                    .values()
238                    .filter(|worktree| worktree.visible)
239                    .map(|worktree| worktree.root_name.clone())
240                    .collect::<Vec<_>>();
241                worktree_root_names.sort_unstable();
242                contacts
243                    .entry(host_user_id)
244                    .or_insert_with(|| proto::Contact {
245                        user_id: host_user_id.to_proto(),
246                        projects: Vec::new(),
247                    })
248                    .projects
249                    .push(proto::ProjectMetadata {
250                        id: *project_id,
251                        worktree_root_names,
252                        is_shared: project.share.is_some(),
253                        guests: guests.into_iter().collect(),
254                    });
255            }
256        }
257
258        contacts.into_values().collect()
259    }
260
261    pub fn register_project(
262        &mut self,
263        host_connection_id: ConnectionId,
264        host_user_id: UserId,
265    ) -> u64 {
266        let project_id = self.next_project_id;
267        self.projects.insert(
268            project_id,
269            Project {
270                host_connection_id,
271                host_user_id,
272                share: None,
273                worktrees: Default::default(),
274                language_servers: Default::default(),
275            },
276        );
277        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
278            connection.projects.insert(project_id);
279        }
280        self.next_project_id += 1;
281        project_id
282    }
283
284    pub fn register_worktree(
285        &mut self,
286        project_id: u64,
287        worktree_id: u64,
288        connection_id: ConnectionId,
289        worktree: Worktree,
290    ) -> Result<()> {
291        let project = self
292            .projects
293            .get_mut(&project_id)
294            .ok_or_else(|| anyhow!("no such project"))?;
295        if project.host_connection_id == connection_id {
296            for authorized_user_id in &worktree.authorized_user_ids {
297                self.visible_projects_by_user_id
298                    .entry(*authorized_user_id)
299                    .or_default()
300                    .insert(project_id);
301            }
302
303            project.worktrees.insert(worktree_id, worktree);
304            if let Ok(share) = project.share_mut() {
305                share.worktrees.insert(worktree_id, Default::default());
306            }
307
308            Ok(())
309        } else {
310            Err(anyhow!("no such project"))?
311        }
312    }
313
314    pub fn unregister_project(
315        &mut self,
316        project_id: u64,
317        connection_id: ConnectionId,
318    ) -> Result<Project> {
319        match self.projects.entry(project_id) {
320            hash_map::Entry::Occupied(e) => {
321                if e.get().host_connection_id == connection_id {
322                    for user_id in e.get().authorized_user_ids() {
323                        if let hash_map::Entry::Occupied(mut projects) =
324                            self.visible_projects_by_user_id.entry(user_id)
325                        {
326                            projects.get_mut().remove(&project_id);
327                        }
328                    }
329
330                    let project = e.remove();
331
332                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
333                        host_connection.projects.remove(&project_id);
334                    }
335
336                    if let Some(share) = &project.share {
337                        for guest_connection in share.guests.keys() {
338                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
339                                connection.projects.remove(&project_id);
340                            }
341                        }
342                    }
343
344                    Ok(project)
345                } else {
346                    Err(anyhow!("no such project"))?
347                }
348            }
349            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
350        }
351    }
352
353    pub fn unregister_worktree(
354        &mut self,
355        project_id: u64,
356        worktree_id: u64,
357        acting_connection_id: ConnectionId,
358    ) -> Result<(Worktree, Vec<ConnectionId>)> {
359        let project = self
360            .projects
361            .get_mut(&project_id)
362            .ok_or_else(|| anyhow!("no such project"))?;
363        if project.host_connection_id != acting_connection_id {
364            Err(anyhow!("not your worktree"))?;
365        }
366
367        let worktree = project
368            .worktrees
369            .remove(&worktree_id)
370            .ok_or_else(|| anyhow!("no such worktree"))?;
371
372        let mut guest_connection_ids = Vec::new();
373        if let Ok(share) = project.share_mut() {
374            guest_connection_ids.extend(share.guests.keys());
375            share.worktrees.remove(&worktree_id);
376        }
377
378        for authorized_user_id in &worktree.authorized_user_ids {
379            if let Some(visible_projects) =
380                self.visible_projects_by_user_id.get_mut(authorized_user_id)
381            {
382                if !project.has_authorized_user_id(*authorized_user_id) {
383                    visible_projects.remove(&project_id);
384                }
385            }
386        }
387
388        Ok((worktree, guest_connection_ids))
389    }
390
391    pub fn share_project(
392        &mut self,
393        project_id: u64,
394        connection_id: ConnectionId,
395    ) -> Result<SharedProject> {
396        if let Some(project) = self.projects.get_mut(&project_id) {
397            if project.host_connection_id == connection_id {
398                let mut share = ProjectShare::default();
399                for worktree_id in project.worktrees.keys() {
400                    share.worktrees.insert(*worktree_id, Default::default());
401                }
402                project.share = Some(share);
403                return Ok(SharedProject {
404                    authorized_user_ids: project.authorized_user_ids(),
405                });
406            }
407        }
408        Err(anyhow!("no such project"))?
409    }
410
411    pub fn unshare_project(
412        &mut self,
413        project_id: u64,
414        acting_connection_id: ConnectionId,
415    ) -> Result<UnsharedProject> {
416        let project = if let Some(project) = self.projects.get_mut(&project_id) {
417            project
418        } else {
419            return Err(anyhow!("no such project"))?;
420        };
421
422        if project.host_connection_id != acting_connection_id {
423            return Err(anyhow!("not your project"))?;
424        }
425
426        let connection_ids = project.connection_ids();
427        let authorized_user_ids = project.authorized_user_ids();
428        if let Some(share) = project.share.take() {
429            for connection_id in share.guests.into_keys() {
430                if let Some(connection) = self.connections.get_mut(&connection_id) {
431                    connection.projects.remove(&project_id);
432                }
433            }
434
435            Ok(UnsharedProject {
436                connection_ids,
437                authorized_user_ids,
438            })
439        } else {
440            Err(anyhow!("project is not shared"))?
441        }
442    }
443
444    pub fn update_diagnostic_summary(
445        &mut self,
446        project_id: u64,
447        worktree_id: u64,
448        connection_id: ConnectionId,
449        summary: proto::DiagnosticSummary,
450    ) -> Result<Vec<ConnectionId>> {
451        let project = self
452            .projects
453            .get_mut(&project_id)
454            .ok_or_else(|| anyhow!("no such project"))?;
455        if project.host_connection_id == connection_id {
456            let worktree = project
457                .share_mut()?
458                .worktrees
459                .get_mut(&worktree_id)
460                .ok_or_else(|| anyhow!("no such worktree"))?;
461            worktree
462                .diagnostic_summaries
463                .insert(summary.path.clone().into(), summary);
464            return Ok(project.connection_ids());
465        }
466
467        Err(anyhow!("no such worktree"))?
468    }
469
470    pub fn start_language_server(
471        &mut self,
472        project_id: u64,
473        connection_id: ConnectionId,
474        language_server: proto::LanguageServer,
475    ) -> Result<Vec<ConnectionId>> {
476        let project = self
477            .projects
478            .get_mut(&project_id)
479            .ok_or_else(|| anyhow!("no such project"))?;
480        if project.host_connection_id == connection_id {
481            project.language_servers.push(language_server);
482            return Ok(project.connection_ids());
483        }
484
485        Err(anyhow!("no such project"))?
486    }
487
488    pub fn join_project(
489        &mut self,
490        connection_id: ConnectionId,
491        user_id: UserId,
492        project_id: u64,
493    ) -> Result<JoinedProject> {
494        let connection = self
495            .connections
496            .get_mut(&connection_id)
497            .ok_or_else(|| anyhow!("no such connection"))?;
498        let project = self
499            .projects
500            .get_mut(&project_id)
501            .and_then(|project| {
502                if project.has_authorized_user_id(user_id) {
503                    Some(project)
504                } else {
505                    None
506                }
507            })
508            .ok_or_else(|| anyhow!("no such project"))?;
509
510        let share = project.share_mut()?;
511        connection.projects.insert(project_id);
512
513        let mut replica_id = 1;
514        while share.active_replica_ids.contains(&replica_id) {
515            replica_id += 1;
516        }
517        share.active_replica_ids.insert(replica_id);
518        share.guests.insert(connection_id, (replica_id, user_id));
519
520        Ok(JoinedProject {
521            replica_id,
522            project: &self.projects[&project_id],
523        })
524    }
525
526    pub fn leave_project(
527        &mut self,
528        connection_id: ConnectionId,
529        project_id: u64,
530    ) -> Result<LeftProject> {
531        let project = self
532            .projects
533            .get_mut(&project_id)
534            .ok_or_else(|| anyhow!("no such project"))?;
535        let share = project
536            .share
537            .as_mut()
538            .ok_or_else(|| anyhow!("project is not shared"))?;
539        let (replica_id, _) = share
540            .guests
541            .remove(&connection_id)
542            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
543        share.active_replica_ids.remove(&replica_id);
544
545        if let Some(connection) = self.connections.get_mut(&connection_id) {
546            connection.projects.remove(&project_id);
547        }
548
549        let connection_ids = project.connection_ids();
550        let authorized_user_ids = project.authorized_user_ids();
551
552        Ok(LeftProject {
553            connection_ids,
554            authorized_user_ids,
555        })
556    }
557
558    pub fn update_worktree(
559        &mut self,
560        connection_id: ConnectionId,
561        project_id: u64,
562        worktree_id: u64,
563        removed_entries: &[u64],
564        updated_entries: &[proto::Entry],
565        scan_id: u64,
566    ) -> Result<Vec<ConnectionId>> {
567        let project = self.write_project(project_id, connection_id)?;
568        let worktree = project
569            .share_mut()?
570            .worktrees
571            .get_mut(&worktree_id)
572            .ok_or_else(|| anyhow!("no such worktree"))?;
573        for entry_id in removed_entries {
574            worktree.entries.remove(&entry_id);
575        }
576        for entry in updated_entries {
577            worktree.entries.insert(entry.id, entry.clone());
578        }
579        worktree.scan_id = scan_id;
580        let connection_ids = project.connection_ids();
581        Ok(connection_ids)
582    }
583
584    pub fn project_connection_ids(
585        &self,
586        project_id: u64,
587        acting_connection_id: ConnectionId,
588    ) -> Result<Vec<ConnectionId>> {
589        Ok(self
590            .read_project(project_id, acting_connection_id)?
591            .connection_ids())
592    }
593
594    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
595        Ok(self
596            .channels
597            .get(&channel_id)
598            .ok_or_else(|| anyhow!("no such channel"))?
599            .connection_ids())
600    }
601
602    #[cfg(test)]
603    pub fn project(&self, project_id: u64) -> Option<&Project> {
604        self.projects.get(&project_id)
605    }
606
607    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
608        let project = self
609            .projects
610            .get(&project_id)
611            .ok_or_else(|| anyhow!("no such project"))?;
612        if project.host_connection_id == connection_id
613            || project
614                .share
615                .as_ref()
616                .ok_or_else(|| anyhow!("project is not shared"))?
617                .guests
618                .contains_key(&connection_id)
619        {
620            Ok(project)
621        } else {
622            Err(anyhow!("no such project"))?
623        }
624    }
625
626    fn write_project(
627        &mut self,
628        project_id: u64,
629        connection_id: ConnectionId,
630    ) -> Result<&mut Project> {
631        let project = self
632            .projects
633            .get_mut(&project_id)
634            .ok_or_else(|| anyhow!("no such project"))?;
635        if project.host_connection_id == connection_id
636            || project
637                .share
638                .as_ref()
639                .ok_or_else(|| anyhow!("project is not shared"))?
640                .guests
641                .contains_key(&connection_id)
642        {
643            Ok(project)
644        } else {
645            Err(anyhow!("no such project"))?
646        }
647    }
648
649    #[cfg(test)]
650    pub fn check_invariants(&self) {
651        for (connection_id, connection) in &self.connections {
652            for project_id in &connection.projects {
653                let project = &self.projects.get(&project_id).unwrap();
654                if project.host_connection_id != *connection_id {
655                    assert!(project
656                        .share
657                        .as_ref()
658                        .unwrap()
659                        .guests
660                        .contains_key(connection_id));
661                }
662
663                if let Some(share) = project.share.as_ref() {
664                    for (worktree_id, worktree) in share.worktrees.iter() {
665                        let mut paths = HashMap::default();
666                        for entry in worktree.entries.values() {
667                            let prev_entry = paths.insert(&entry.path, entry);
668                            assert_eq!(
669                                prev_entry,
670                                None,
671                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
672                                worktree_id,
673                                prev_entry.unwrap(),
674                                entry
675                            );
676                        }
677                    }
678                }
679            }
680            for channel_id in &connection.channels {
681                let channel = self.channels.get(channel_id).unwrap();
682                assert!(channel.connection_ids.contains(connection_id));
683            }
684            assert!(self
685                .connections_by_user_id
686                .get(&connection.user_id)
687                .unwrap()
688                .contains(connection_id));
689        }
690
691        for (user_id, connection_ids) in &self.connections_by_user_id {
692            for connection_id in connection_ids {
693                assert_eq!(
694                    self.connections.get(connection_id).unwrap().user_id,
695                    *user_id
696                );
697            }
698        }
699
700        for (project_id, project) in &self.projects {
701            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
702            assert!(host_connection.projects.contains(project_id));
703
704            for authorized_user_ids in project.authorized_user_ids() {
705                let visible_project_ids = self
706                    .visible_projects_by_user_id
707                    .get(&authorized_user_ids)
708                    .unwrap();
709                assert!(visible_project_ids.contains(project_id));
710            }
711
712            if let Some(share) = &project.share {
713                for guest_connection_id in share.guests.keys() {
714                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
715                    assert!(guest_connection.projects.contains(project_id));
716                }
717                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
718                assert_eq!(
719                    share.active_replica_ids,
720                    share
721                        .guests
722                        .values()
723                        .map(|(replica_id, _)| *replica_id)
724                        .collect::<HashSet<_>>(),
725                );
726            }
727        }
728
729        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
730            for project_id in visible_project_ids {
731                let project = self.projects.get(project_id).unwrap();
732                assert!(project.authorized_user_ids().contains(user_id));
733            }
734        }
735
736        for (channel_id, channel) in &self.channels {
737            for connection_id in &channel.connection_ids {
738                let connection = self.connections.get(connection_id).unwrap();
739                assert!(connection.channels.contains(channel_id));
740            }
741        }
742    }
743}
744
745impl Project {
746    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
747        self.worktrees
748            .values()
749            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
750    }
751
752    pub fn authorized_user_ids(&self) -> Vec<UserId> {
753        let mut ids = self
754            .worktrees
755            .values()
756            .flat_map(|worktree| worktree.authorized_user_ids.iter())
757            .copied()
758            .collect::<Vec<_>>();
759        ids.sort_unstable();
760        ids.dedup();
761        ids
762    }
763
764    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
765        if let Some(share) = &self.share {
766            share.guests.keys().copied().collect()
767        } else {
768            Vec::new()
769        }
770    }
771
772    pub fn connection_ids(&self) -> Vec<ConnectionId> {
773        if let Some(share) = &self.share {
774            share
775                .guests
776                .keys()
777                .copied()
778                .chain(Some(self.host_connection_id))
779                .collect()
780        } else {
781            vec![self.host_connection_id]
782        }
783    }
784
785    pub fn share(&self) -> Result<&ProjectShare> {
786        Ok(self
787            .share
788            .as_ref()
789            .ok_or_else(|| anyhow!("worktree is not shared"))?)
790    }
791
792    fn share_mut(&mut self) -> Result<&mut ProjectShare> {
793        Ok(self
794            .share
795            .as_mut()
796            .ok_or_else(|| anyhow!("worktree is not shared"))?)
797    }
798}
799
800impl Channel {
801    fn connection_ids(&self) -> Vec<ConnectionId> {
802        self.connection_ids.iter().copied().collect()
803    }
804}