store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 14    channels: HashMap<ChannelId, Channel>,
 15    next_project_id: u64,
 16}
 17
 18struct ConnectionState {
 19    user_id: UserId,
 20    projects: HashSet<u64>,
 21    channels: HashSet<ChannelId>,
 22}
 23
 24pub struct Project {
 25    pub host_connection_id: ConnectionId,
 26    pub host_user_id: UserId,
 27    pub share: Option<ProjectShare>,
 28    pub worktrees: HashMap<u64, Worktree>,
 29    pub language_servers: Vec<proto::LanguageServer>,
 30}
 31
 32pub struct Worktree {
 33    pub authorized_user_ids: Vec<UserId>,
 34    pub root_name: String,
 35    pub visible: bool,
 36}
 37
 38#[derive(Default)]
 39pub struct ProjectShare {
 40    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 41    pub active_replica_ids: HashSet<ReplicaId>,
 42    pub worktrees: HashMap<u64, WorktreeShare>,
 43}
 44
 45#[derive(Default)]
 46pub struct WorktreeShare {
 47    pub entries: HashMap<u64, proto::Entry>,
 48    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 49}
 50
 51#[derive(Default)]
 52pub struct Channel {
 53    pub connection_ids: HashSet<ConnectionId>,
 54}
 55
 56pub type ReplicaId = u16;
 57
 58#[derive(Default)]
 59pub struct RemovedConnectionState {
 60    pub hosted_projects: HashMap<u64, Project>,
 61    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 62    pub contact_ids: HashSet<UserId>,
 63}
 64
 65pub struct JoinedProject<'a> {
 66    pub replica_id: ReplicaId,
 67    pub project: &'a Project,
 68}
 69
 70pub struct SharedProject {
 71    pub authorized_user_ids: Vec<UserId>,
 72}
 73
 74pub struct UnsharedProject {
 75    pub connection_ids: Vec<ConnectionId>,
 76    pub authorized_user_ids: Vec<UserId>,
 77}
 78
 79pub struct LeftProject {
 80    pub connection_ids: Vec<ConnectionId>,
 81    pub authorized_user_ids: Vec<UserId>,
 82}
 83
 84#[derive(Copy, Clone)]
 85pub struct Metrics {
 86    pub connections: usize,
 87    pub registered_projects: usize,
 88    pub shared_projects: usize,
 89}
 90
 91impl Store {
 92    pub fn metrics(&self) -> Metrics {
 93        let connections = self.connections.len();
 94        let mut registered_projects = 0;
 95        let mut shared_projects = 0;
 96        for project in self.projects.values() {
 97            registered_projects += 1;
 98            if project.share.is_some() {
 99                shared_projects += 1;
100            }
101        }
102
103        Metrics {
104            connections,
105            registered_projects,
106            shared_projects,
107        }
108    }
109
110    #[instrument(skip(self))]
111    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
112        self.connections.insert(
113            connection_id,
114            ConnectionState {
115                user_id,
116                projects: Default::default(),
117                channels: Default::default(),
118            },
119        );
120        self.connections_by_user_id
121            .entry(user_id)
122            .or_default()
123            .insert(connection_id);
124    }
125
126    #[instrument(skip(self))]
127    pub fn remove_connection(
128        &mut self,
129        connection_id: ConnectionId,
130    ) -> Result<RemovedConnectionState> {
131        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
132            connection
133        } else {
134            return Err(anyhow!("no such connection"))?;
135        };
136
137        for channel_id in &connection.channels {
138            if let Some(channel) = self.channels.get_mut(&channel_id) {
139                channel.connection_ids.remove(&connection_id);
140            }
141        }
142
143        let user_connections = self
144            .connections_by_user_id
145            .get_mut(&connection.user_id)
146            .unwrap();
147        user_connections.remove(&connection_id);
148        if user_connections.is_empty() {
149            self.connections_by_user_id.remove(&connection.user_id);
150        }
151
152        let mut result = RemovedConnectionState::default();
153        for project_id in connection.projects.clone() {
154            if let Ok(project) = self.unregister_project(project_id, connection_id) {
155                result.contact_ids.extend(project.authorized_user_ids());
156                result.hosted_projects.insert(project_id, project);
157            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
158                result
159                    .guest_project_ids
160                    .insert(project_id, project.connection_ids);
161                result.contact_ids.extend(project.authorized_user_ids);
162            }
163        }
164
165        Ok(result)
166    }
167
168    #[cfg(test)]
169    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
170        self.channels.get(&id)
171    }
172
173    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
174        if let Some(connection) = self.connections.get_mut(&connection_id) {
175            connection.channels.insert(channel_id);
176            self.channels
177                .entry(channel_id)
178                .or_default()
179                .connection_ids
180                .insert(connection_id);
181        }
182    }
183
184    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
185        if let Some(connection) = self.connections.get_mut(&connection_id) {
186            connection.channels.remove(&channel_id);
187            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
188                entry.get_mut().connection_ids.remove(&connection_id);
189                if entry.get_mut().connection_ids.is_empty() {
190                    entry.remove();
191                }
192            }
193        }
194    }
195
196    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
197        Ok(self
198            .connections
199            .get(&connection_id)
200            .ok_or_else(|| anyhow!("unknown connection"))?
201            .user_id)
202    }
203
204    pub fn connection_ids_for_user<'a>(
205        &'a self,
206        user_id: UserId,
207    ) -> impl 'a + Iterator<Item = ConnectionId> {
208        self.connections_by_user_id
209            .get(&user_id)
210            .into_iter()
211            .flatten()
212            .copied()
213    }
214
215    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
216        let mut contacts = HashMap::default();
217        for project_id in self
218            .visible_projects_by_user_id
219            .get(&user_id)
220            .unwrap_or(&HashSet::default())
221        {
222            let project = &self.projects[project_id];
223
224            let mut guests = HashSet::default();
225            if let Ok(share) = project.share() {
226                for guest_connection_id in share.guests.keys() {
227                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
228                        guests.insert(user_id.to_proto());
229                    }
230                }
231            }
232
233            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
234                let mut worktree_root_names = project
235                    .worktrees
236                    .values()
237                    .filter(|worktree| worktree.visible)
238                    .map(|worktree| worktree.root_name.clone())
239                    .collect::<Vec<_>>();
240                worktree_root_names.sort_unstable();
241                contacts
242                    .entry(host_user_id)
243                    .or_insert_with(|| proto::Contact {
244                        user_id: host_user_id.to_proto(),
245                        projects: Vec::new(),
246                    })
247                    .projects
248                    .push(proto::ProjectMetadata {
249                        id: *project_id,
250                        worktree_root_names,
251                        is_shared: project.share.is_some(),
252                        guests: guests.into_iter().collect(),
253                    });
254            }
255        }
256
257        contacts.into_values().collect()
258    }
259
260    pub fn register_project(
261        &mut self,
262        host_connection_id: ConnectionId,
263        host_user_id: UserId,
264    ) -> u64 {
265        let project_id = self.next_project_id;
266        self.projects.insert(
267            project_id,
268            Project {
269                host_connection_id,
270                host_user_id,
271                share: None,
272                worktrees: Default::default(),
273                language_servers: Default::default(),
274            },
275        );
276        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
277            connection.projects.insert(project_id);
278        }
279        self.next_project_id += 1;
280        project_id
281    }
282
283    pub fn register_worktree(
284        &mut self,
285        project_id: u64,
286        worktree_id: u64,
287        connection_id: ConnectionId,
288        worktree: Worktree,
289    ) -> Result<()> {
290        let project = self
291            .projects
292            .get_mut(&project_id)
293            .ok_or_else(|| anyhow!("no such project"))?;
294        if project.host_connection_id == connection_id {
295            for authorized_user_id in &worktree.authorized_user_ids {
296                self.visible_projects_by_user_id
297                    .entry(*authorized_user_id)
298                    .or_default()
299                    .insert(project_id);
300            }
301
302            project.worktrees.insert(worktree_id, worktree);
303            if let Ok(share) = project.share_mut() {
304                share.worktrees.insert(worktree_id, Default::default());
305            }
306
307            Ok(())
308        } else {
309            Err(anyhow!("no such project"))?
310        }
311    }
312
313    pub fn unregister_project(
314        &mut self,
315        project_id: u64,
316        connection_id: ConnectionId,
317    ) -> Result<Project> {
318        match self.projects.entry(project_id) {
319            hash_map::Entry::Occupied(e) => {
320                if e.get().host_connection_id == connection_id {
321                    for user_id in e.get().authorized_user_ids() {
322                        if let hash_map::Entry::Occupied(mut projects) =
323                            self.visible_projects_by_user_id.entry(user_id)
324                        {
325                            projects.get_mut().remove(&project_id);
326                        }
327                    }
328
329                    let project = e.remove();
330
331                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
332                        host_connection.projects.remove(&project_id);
333                    }
334
335                    if let Some(share) = &project.share {
336                        for guest_connection in share.guests.keys() {
337                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
338                                connection.projects.remove(&project_id);
339                            }
340                        }
341                    }
342
343                    Ok(project)
344                } else {
345                    Err(anyhow!("no such project"))?
346                }
347            }
348            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
349        }
350    }
351
352    pub fn unregister_worktree(
353        &mut self,
354        project_id: u64,
355        worktree_id: u64,
356        acting_connection_id: ConnectionId,
357    ) -> Result<(Worktree, Vec<ConnectionId>)> {
358        let project = self
359            .projects
360            .get_mut(&project_id)
361            .ok_or_else(|| anyhow!("no such project"))?;
362        if project.host_connection_id != acting_connection_id {
363            Err(anyhow!("not your worktree"))?;
364        }
365
366        let worktree = project
367            .worktrees
368            .remove(&worktree_id)
369            .ok_or_else(|| anyhow!("no such worktree"))?;
370
371        let mut guest_connection_ids = Vec::new();
372        if let Ok(share) = project.share_mut() {
373            guest_connection_ids.extend(share.guests.keys());
374            share.worktrees.remove(&worktree_id);
375        }
376
377        for authorized_user_id in &worktree.authorized_user_ids {
378            if let Some(visible_projects) =
379                self.visible_projects_by_user_id.get_mut(authorized_user_id)
380            {
381                if !project.has_authorized_user_id(*authorized_user_id) {
382                    visible_projects.remove(&project_id);
383                }
384            }
385        }
386
387        Ok((worktree, guest_connection_ids))
388    }
389
390    pub fn share_project(
391        &mut self,
392        project_id: u64,
393        connection_id: ConnectionId,
394    ) -> Result<SharedProject> {
395        if let Some(project) = self.projects.get_mut(&project_id) {
396            if project.host_connection_id == connection_id {
397                let mut share = ProjectShare::default();
398                for worktree_id in project.worktrees.keys() {
399                    share.worktrees.insert(*worktree_id, Default::default());
400                }
401                project.share = Some(share);
402                return Ok(SharedProject {
403                    authorized_user_ids: project.authorized_user_ids(),
404                });
405            }
406        }
407        Err(anyhow!("no such project"))?
408    }
409
410    pub fn unshare_project(
411        &mut self,
412        project_id: u64,
413        acting_connection_id: ConnectionId,
414    ) -> Result<UnsharedProject> {
415        let project = if let Some(project) = self.projects.get_mut(&project_id) {
416            project
417        } else {
418            return Err(anyhow!("no such project"))?;
419        };
420
421        if project.host_connection_id != acting_connection_id {
422            return Err(anyhow!("not your project"))?;
423        }
424
425        let connection_ids = project.connection_ids();
426        let authorized_user_ids = project.authorized_user_ids();
427        if let Some(share) = project.share.take() {
428            for connection_id in share.guests.into_keys() {
429                if let Some(connection) = self.connections.get_mut(&connection_id) {
430                    connection.projects.remove(&project_id);
431                }
432            }
433
434            Ok(UnsharedProject {
435                connection_ids,
436                authorized_user_ids,
437            })
438        } else {
439            Err(anyhow!("project is not shared"))?
440        }
441    }
442
443    pub fn update_diagnostic_summary(
444        &mut self,
445        project_id: u64,
446        worktree_id: u64,
447        connection_id: ConnectionId,
448        summary: proto::DiagnosticSummary,
449    ) -> Result<Vec<ConnectionId>> {
450        let project = self
451            .projects
452            .get_mut(&project_id)
453            .ok_or_else(|| anyhow!("no such project"))?;
454        if project.host_connection_id == connection_id {
455            let worktree = project
456                .share_mut()?
457                .worktrees
458                .get_mut(&worktree_id)
459                .ok_or_else(|| anyhow!("no such worktree"))?;
460            worktree
461                .diagnostic_summaries
462                .insert(summary.path.clone().into(), summary);
463            return Ok(project.connection_ids());
464        }
465
466        Err(anyhow!("no such worktree"))?
467    }
468
469    pub fn start_language_server(
470        &mut self,
471        project_id: u64,
472        connection_id: ConnectionId,
473        language_server: proto::LanguageServer,
474    ) -> Result<Vec<ConnectionId>> {
475        let project = self
476            .projects
477            .get_mut(&project_id)
478            .ok_or_else(|| anyhow!("no such project"))?;
479        if project.host_connection_id == connection_id {
480            project.language_servers.push(language_server);
481            return Ok(project.connection_ids());
482        }
483
484        Err(anyhow!("no such project"))?
485    }
486
487    pub fn join_project(
488        &mut self,
489        connection_id: ConnectionId,
490        user_id: UserId,
491        project_id: u64,
492    ) -> Result<JoinedProject> {
493        let connection = self
494            .connections
495            .get_mut(&connection_id)
496            .ok_or_else(|| anyhow!("no such connection"))?;
497        let project = self
498            .projects
499            .get_mut(&project_id)
500            .and_then(|project| {
501                if project.has_authorized_user_id(user_id) {
502                    Some(project)
503                } else {
504                    None
505                }
506            })
507            .ok_or_else(|| anyhow!("no such project"))?;
508
509        let share = project.share_mut()?;
510        connection.projects.insert(project_id);
511
512        let mut replica_id = 1;
513        while share.active_replica_ids.contains(&replica_id) {
514            replica_id += 1;
515        }
516        share.active_replica_ids.insert(replica_id);
517        share.guests.insert(connection_id, (replica_id, user_id));
518
519        Ok(JoinedProject {
520            replica_id,
521            project: &self.projects[&project_id],
522        })
523    }
524
525    pub fn leave_project(
526        &mut self,
527        connection_id: ConnectionId,
528        project_id: u64,
529    ) -> Result<LeftProject> {
530        let project = self
531            .projects
532            .get_mut(&project_id)
533            .ok_or_else(|| anyhow!("no such project"))?;
534        let share = project
535            .share
536            .as_mut()
537            .ok_or_else(|| anyhow!("project is not shared"))?;
538        let (replica_id, _) = share
539            .guests
540            .remove(&connection_id)
541            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
542        share.active_replica_ids.remove(&replica_id);
543
544        if let Some(connection) = self.connections.get_mut(&connection_id) {
545            connection.projects.remove(&project_id);
546        }
547
548        let connection_ids = project.connection_ids();
549        let authorized_user_ids = project.authorized_user_ids();
550
551        Ok(LeftProject {
552            connection_ids,
553            authorized_user_ids,
554        })
555    }
556
557    pub fn update_worktree(
558        &mut self,
559        connection_id: ConnectionId,
560        project_id: u64,
561        worktree_id: u64,
562        removed_entries: &[u64],
563        updated_entries: &[proto::Entry],
564    ) -> Result<Vec<ConnectionId>> {
565        let project = self.write_project(project_id, connection_id)?;
566        let worktree = project
567            .share_mut()?
568            .worktrees
569            .get_mut(&worktree_id)
570            .ok_or_else(|| anyhow!("no such worktree"))?;
571        for entry_id in removed_entries {
572            worktree.entries.remove(&entry_id);
573        }
574        for entry in updated_entries {
575            worktree.entries.insert(entry.id, entry.clone());
576        }
577        let connection_ids = project.connection_ids();
578        Ok(connection_ids)
579    }
580
581    pub fn project_connection_ids(
582        &self,
583        project_id: u64,
584        acting_connection_id: ConnectionId,
585    ) -> Result<Vec<ConnectionId>> {
586        Ok(self
587            .read_project(project_id, acting_connection_id)?
588            .connection_ids())
589    }
590
591    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
592        Ok(self
593            .channels
594            .get(&channel_id)
595            .ok_or_else(|| anyhow!("no such channel"))?
596            .connection_ids())
597    }
598
599    #[cfg(test)]
600    pub fn project(&self, project_id: u64) -> Option<&Project> {
601        self.projects.get(&project_id)
602    }
603
604    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
605        let project = self
606            .projects
607            .get(&project_id)
608            .ok_or_else(|| anyhow!("no such project"))?;
609        if project.host_connection_id == connection_id
610            || project
611                .share
612                .as_ref()
613                .ok_or_else(|| anyhow!("project is not shared"))?
614                .guests
615                .contains_key(&connection_id)
616        {
617            Ok(project)
618        } else {
619            Err(anyhow!("no such project"))?
620        }
621    }
622
623    fn write_project(
624        &mut self,
625        project_id: u64,
626        connection_id: ConnectionId,
627    ) -> Result<&mut Project> {
628        let project = self
629            .projects
630            .get_mut(&project_id)
631            .ok_or_else(|| anyhow!("no such project"))?;
632        if project.host_connection_id == connection_id
633            || project
634                .share
635                .as_ref()
636                .ok_or_else(|| anyhow!("project is not shared"))?
637                .guests
638                .contains_key(&connection_id)
639        {
640            Ok(project)
641        } else {
642            Err(anyhow!("no such project"))?
643        }
644    }
645
646    #[cfg(test)]
647    pub fn check_invariants(&self) {
648        for (connection_id, connection) in &self.connections {
649            for project_id in &connection.projects {
650                let project = &self.projects.get(&project_id).unwrap();
651                if project.host_connection_id != *connection_id {
652                    assert!(project
653                        .share
654                        .as_ref()
655                        .unwrap()
656                        .guests
657                        .contains_key(connection_id));
658                }
659
660                if let Some(share) = project.share.as_ref() {
661                    for (worktree_id, worktree) in share.worktrees.iter() {
662                        let mut paths = HashMap::default();
663                        for entry in worktree.entries.values() {
664                            let prev_entry = paths.insert(&entry.path, entry);
665                            assert_eq!(
666                                prev_entry,
667                                None,
668                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
669                                worktree_id,
670                                prev_entry.unwrap(),
671                                entry
672                            );
673                        }
674                    }
675                }
676            }
677            for channel_id in &connection.channels {
678                let channel = self.channels.get(channel_id).unwrap();
679                assert!(channel.connection_ids.contains(connection_id));
680            }
681            assert!(self
682                .connections_by_user_id
683                .get(&connection.user_id)
684                .unwrap()
685                .contains(connection_id));
686        }
687
688        for (user_id, connection_ids) in &self.connections_by_user_id {
689            for connection_id in connection_ids {
690                assert_eq!(
691                    self.connections.get(connection_id).unwrap().user_id,
692                    *user_id
693                );
694            }
695        }
696
697        for (project_id, project) in &self.projects {
698            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
699            assert!(host_connection.projects.contains(project_id));
700
701            for authorized_user_ids in project.authorized_user_ids() {
702                let visible_project_ids = self
703                    .visible_projects_by_user_id
704                    .get(&authorized_user_ids)
705                    .unwrap();
706                assert!(visible_project_ids.contains(project_id));
707            }
708
709            if let Some(share) = &project.share {
710                for guest_connection_id in share.guests.keys() {
711                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
712                    assert!(guest_connection.projects.contains(project_id));
713                }
714                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
715                assert_eq!(
716                    share.active_replica_ids,
717                    share
718                        .guests
719                        .values()
720                        .map(|(replica_id, _)| *replica_id)
721                        .collect::<HashSet<_>>(),
722                );
723            }
724        }
725
726        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
727            for project_id in visible_project_ids {
728                let project = self.projects.get(project_id).unwrap();
729                assert!(project.authorized_user_ids().contains(user_id));
730            }
731        }
732
733        for (channel_id, channel) in &self.channels {
734            for connection_id in &channel.connection_ids {
735                let connection = self.connections.get(connection_id).unwrap();
736                assert!(connection.channels.contains(channel_id));
737            }
738        }
739    }
740}
741
742impl Project {
743    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
744        self.worktrees
745            .values()
746            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
747    }
748
749    pub fn authorized_user_ids(&self) -> Vec<UserId> {
750        let mut ids = self
751            .worktrees
752            .values()
753            .flat_map(|worktree| worktree.authorized_user_ids.iter())
754            .copied()
755            .collect::<Vec<_>>();
756        ids.sort_unstable();
757        ids.dedup();
758        ids
759    }
760
761    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
762        if let Some(share) = &self.share {
763            share.guests.keys().copied().collect()
764        } else {
765            Vec::new()
766        }
767    }
768
769    pub fn connection_ids(&self) -> Vec<ConnectionId> {
770        if let Some(share) = &self.share {
771            share
772                .guests
773                .keys()
774                .copied()
775                .chain(Some(self.host_connection_id))
776                .collect()
777        } else {
778            vec![self.host_connection_id]
779        }
780    }
781
782    pub fn share(&self) -> Result<&ProjectShare> {
783        Ok(self
784            .share
785            .as_ref()
786            .ok_or_else(|| anyhow!("worktree is not shared"))?)
787    }
788
789    fn share_mut(&mut self) -> Result<&mut ProjectShare> {
790        Ok(self
791            .share
792            .as_mut()
793            .ok_or_else(|| anyhow!("worktree is not shared"))?)
794    }
795}
796
797impl Channel {
798    fn connection_ids(&self) -> Vec<ConnectionId> {
799        self.connection_ids.iter().copied().collect()
800    }
801}