store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub authorized_user_ids: Vec<UserId>,
 33    pub root_name: String,
 34    pub visible: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41    pub worktrees: HashMap<u64, WorktreeShare>,
 42}
 43
 44#[derive(Default)]
 45pub struct WorktreeShare {
 46    pub entries: HashMap<u64, proto::Entry>,
 47    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub hosted_projects: HashMap<u64, Project>,
 60    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 61    pub contact_ids: HashSet<UserId>,
 62}
 63
 64pub struct JoinedProject<'a> {
 65    pub replica_id: ReplicaId,
 66    pub project: &'a Project,
 67}
 68
 69pub struct SharedProject {
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct UnsharedProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78pub struct LeftProject {
 79    pub connection_ids: Vec<ConnectionId>,
 80    pub authorized_user_ids: Vec<UserId>,
 81}
 82
 83impl Store {
 84    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 85        self.connections.insert(
 86            connection_id,
 87            ConnectionState {
 88                user_id,
 89                projects: Default::default(),
 90                channels: Default::default(),
 91            },
 92        );
 93        self.connections_by_user_id
 94            .entry(user_id)
 95            .or_default()
 96            .insert(connection_id);
 97    }
 98
 99    pub fn remove_connection(
100        &mut self,
101        connection_id: ConnectionId,
102    ) -> Result<RemovedConnectionState> {
103        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
104            connection
105        } else {
106            return Err(anyhow!("no such connection"))?;
107        };
108
109        for channel_id in &connection.channels {
110            if let Some(channel) = self.channels.get_mut(&channel_id) {
111                channel.connection_ids.remove(&connection_id);
112            }
113        }
114
115        let user_connections = self
116            .connections_by_user_id
117            .get_mut(&connection.user_id)
118            .unwrap();
119        user_connections.remove(&connection_id);
120        if user_connections.is_empty() {
121            self.connections_by_user_id.remove(&connection.user_id);
122        }
123
124        let mut result = RemovedConnectionState::default();
125        for project_id in connection.projects.clone() {
126            if let Ok(project) = self.unregister_project(project_id, connection_id) {
127                result.contact_ids.extend(project.authorized_user_ids());
128                result.hosted_projects.insert(project_id, project);
129            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
130                result
131                    .guest_project_ids
132                    .insert(project_id, project.connection_ids);
133                result.contact_ids.extend(project.authorized_user_ids);
134            }
135        }
136
137        Ok(result)
138    }
139
140    #[cfg(test)]
141    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
142        self.channels.get(&id)
143    }
144
145    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
146        if let Some(connection) = self.connections.get_mut(&connection_id) {
147            connection.channels.insert(channel_id);
148            self.channels
149                .entry(channel_id)
150                .or_default()
151                .connection_ids
152                .insert(connection_id);
153        }
154    }
155
156    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
157        if let Some(connection) = self.connections.get_mut(&connection_id) {
158            connection.channels.remove(&channel_id);
159            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
160                entry.get_mut().connection_ids.remove(&connection_id);
161                if entry.get_mut().connection_ids.is_empty() {
162                    entry.remove();
163                }
164            }
165        }
166    }
167
168    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
169        Ok(self
170            .connections
171            .get(&connection_id)
172            .ok_or_else(|| anyhow!("unknown connection"))?
173            .user_id)
174    }
175
176    pub fn connection_ids_for_user<'a>(
177        &'a self,
178        user_id: UserId,
179    ) -> impl 'a + Iterator<Item = ConnectionId> {
180        self.connections_by_user_id
181            .get(&user_id)
182            .into_iter()
183            .flatten()
184            .copied()
185    }
186
187    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
188        let mut contacts = HashMap::default();
189        for project_id in self
190            .visible_projects_by_user_id
191            .get(&user_id)
192            .unwrap_or(&HashSet::default())
193        {
194            let project = &self.projects[project_id];
195
196            let mut guests = HashSet::default();
197            if let Ok(share) = project.share() {
198                for guest_connection_id in share.guests.keys() {
199                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
200                        guests.insert(user_id.to_proto());
201                    }
202                }
203            }
204
205            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
206                let mut worktree_root_names = project
207                    .worktrees
208                    .values()
209                    .filter(|worktree| worktree.visible)
210                    .map(|worktree| worktree.root_name.clone())
211                    .collect::<Vec<_>>();
212                worktree_root_names.sort_unstable();
213                contacts
214                    .entry(host_user_id)
215                    .or_insert_with(|| proto::Contact {
216                        user_id: host_user_id.to_proto(),
217                        projects: Vec::new(),
218                    })
219                    .projects
220                    .push(proto::ProjectMetadata {
221                        id: *project_id,
222                        worktree_root_names,
223                        is_shared: project.share.is_some(),
224                        guests: guests.into_iter().collect(),
225                    });
226            }
227        }
228
229        contacts.into_values().collect()
230    }
231
232    pub fn register_project(
233        &mut self,
234        host_connection_id: ConnectionId,
235        host_user_id: UserId,
236    ) -> u64 {
237        let project_id = self.next_project_id;
238        self.projects.insert(
239            project_id,
240            Project {
241                host_connection_id,
242                host_user_id,
243                share: None,
244                worktrees: Default::default(),
245                language_servers: Default::default(),
246            },
247        );
248        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
249            connection.projects.insert(project_id);
250        }
251        self.next_project_id += 1;
252        project_id
253    }
254
255    pub fn register_worktree(
256        &mut self,
257        project_id: u64,
258        worktree_id: u64,
259        connection_id: ConnectionId,
260        worktree: Worktree,
261    ) -> Result<()> {
262        let project = self
263            .projects
264            .get_mut(&project_id)
265            .ok_or_else(|| anyhow!("no such project"))?;
266        if project.host_connection_id == connection_id {
267            for authorized_user_id in &worktree.authorized_user_ids {
268                self.visible_projects_by_user_id
269                    .entry(*authorized_user_id)
270                    .or_default()
271                    .insert(project_id);
272            }
273
274            project.worktrees.insert(worktree_id, worktree);
275            if let Ok(share) = project.share_mut() {
276                share.worktrees.insert(worktree_id, Default::default());
277            }
278
279            Ok(())
280        } else {
281            Err(anyhow!("no such project"))?
282        }
283    }
284
285    pub fn unregister_project(
286        &mut self,
287        project_id: u64,
288        connection_id: ConnectionId,
289    ) -> Result<Project> {
290        match self.projects.entry(project_id) {
291            hash_map::Entry::Occupied(e) => {
292                if e.get().host_connection_id == connection_id {
293                    for user_id in e.get().authorized_user_ids() {
294                        if let hash_map::Entry::Occupied(mut projects) =
295                            self.visible_projects_by_user_id.entry(user_id)
296                        {
297                            projects.get_mut().remove(&project_id);
298                        }
299                    }
300
301                    let project = e.remove();
302
303                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
304                        host_connection.projects.remove(&project_id);
305                    }
306
307                    if let Some(share) = &project.share {
308                        for guest_connection in share.guests.keys() {
309                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
310                                connection.projects.remove(&project_id);
311                            }
312                        }
313                    }
314
315                    Ok(project)
316                } else {
317                    Err(anyhow!("no such project"))?
318                }
319            }
320            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
321        }
322    }
323
324    pub fn unregister_worktree(
325        &mut self,
326        project_id: u64,
327        worktree_id: u64,
328        acting_connection_id: ConnectionId,
329    ) -> Result<(Worktree, Vec<ConnectionId>)> {
330        let project = self
331            .projects
332            .get_mut(&project_id)
333            .ok_or_else(|| anyhow!("no such project"))?;
334        if project.host_connection_id != acting_connection_id {
335            Err(anyhow!("not your worktree"))?;
336        }
337
338        let worktree = project
339            .worktrees
340            .remove(&worktree_id)
341            .ok_or_else(|| anyhow!("no such worktree"))?;
342
343        let mut guest_connection_ids = Vec::new();
344        if let Ok(share) = project.share_mut() {
345            guest_connection_ids.extend(share.guests.keys());
346            share.worktrees.remove(&worktree_id);
347        }
348
349        for authorized_user_id in &worktree.authorized_user_ids {
350            if let Some(visible_projects) =
351                self.visible_projects_by_user_id.get_mut(authorized_user_id)
352            {
353                if !project.has_authorized_user_id(*authorized_user_id) {
354                    visible_projects.remove(&project_id);
355                }
356            }
357        }
358
359        Ok((worktree, guest_connection_ids))
360    }
361
362    pub fn share_project(
363        &mut self,
364        project_id: u64,
365        connection_id: ConnectionId,
366    ) -> Result<SharedProject> {
367        if let Some(project) = self.projects.get_mut(&project_id) {
368            if project.host_connection_id == connection_id {
369                let mut share = ProjectShare::default();
370                for worktree_id in project.worktrees.keys() {
371                    share.worktrees.insert(*worktree_id, Default::default());
372                }
373                project.share = Some(share);
374                return Ok(SharedProject {
375                    authorized_user_ids: project.authorized_user_ids(),
376                });
377            }
378        }
379        Err(anyhow!("no such project"))?
380    }
381
382    pub fn unshare_project(
383        &mut self,
384        project_id: u64,
385        acting_connection_id: ConnectionId,
386    ) -> Result<UnsharedProject> {
387        let project = if let Some(project) = self.projects.get_mut(&project_id) {
388            project
389        } else {
390            return Err(anyhow!("no such project"))?;
391        };
392
393        if project.host_connection_id != acting_connection_id {
394            return Err(anyhow!("not your project"))?;
395        }
396
397        let connection_ids = project.connection_ids();
398        let authorized_user_ids = project.authorized_user_ids();
399        if let Some(share) = project.share.take() {
400            for connection_id in share.guests.into_keys() {
401                if let Some(connection) = self.connections.get_mut(&connection_id) {
402                    connection.projects.remove(&project_id);
403                }
404            }
405
406            Ok(UnsharedProject {
407                connection_ids,
408                authorized_user_ids,
409            })
410        } else {
411            Err(anyhow!("project is not shared"))?
412        }
413    }
414
415    pub fn update_diagnostic_summary(
416        &mut self,
417        project_id: u64,
418        worktree_id: u64,
419        connection_id: ConnectionId,
420        summary: proto::DiagnosticSummary,
421    ) -> Result<Vec<ConnectionId>> {
422        let project = self
423            .projects
424            .get_mut(&project_id)
425            .ok_or_else(|| anyhow!("no such project"))?;
426        if project.host_connection_id == connection_id {
427            let worktree = project
428                .share_mut()?
429                .worktrees
430                .get_mut(&worktree_id)
431                .ok_or_else(|| anyhow!("no such worktree"))?;
432            worktree
433                .diagnostic_summaries
434                .insert(summary.path.clone().into(), summary);
435            return Ok(project.connection_ids());
436        }
437
438        Err(anyhow!("no such worktree"))?
439    }
440
441    pub fn start_language_server(
442        &mut self,
443        project_id: u64,
444        connection_id: ConnectionId,
445        language_server: proto::LanguageServer,
446    ) -> Result<Vec<ConnectionId>> {
447        let project = self
448            .projects
449            .get_mut(&project_id)
450            .ok_or_else(|| anyhow!("no such project"))?;
451        if project.host_connection_id == connection_id {
452            project.language_servers.push(language_server);
453            return Ok(project.connection_ids());
454        }
455
456        Err(anyhow!("no such project"))?
457    }
458
459    pub fn join_project(
460        &mut self,
461        connection_id: ConnectionId,
462        user_id: UserId,
463        project_id: u64,
464    ) -> Result<JoinedProject> {
465        let connection = self
466            .connections
467            .get_mut(&connection_id)
468            .ok_or_else(|| anyhow!("no such connection"))?;
469        let project = self
470            .projects
471            .get_mut(&project_id)
472            .and_then(|project| {
473                if project.has_authorized_user_id(user_id) {
474                    Some(project)
475                } else {
476                    None
477                }
478            })
479            .ok_or_else(|| anyhow!("no such project"))?;
480
481        let share = project.share_mut()?;
482        connection.projects.insert(project_id);
483
484        let mut replica_id = 1;
485        while share.active_replica_ids.contains(&replica_id) {
486            replica_id += 1;
487        }
488        share.active_replica_ids.insert(replica_id);
489        share.guests.insert(connection_id, (replica_id, user_id));
490
491        Ok(JoinedProject {
492            replica_id,
493            project: &self.projects[&project_id],
494        })
495    }
496
497    pub fn leave_project(
498        &mut self,
499        connection_id: ConnectionId,
500        project_id: u64,
501    ) -> Result<LeftProject> {
502        let project = self
503            .projects
504            .get_mut(&project_id)
505            .ok_or_else(|| anyhow!("no such project"))?;
506        let share = project
507            .share
508            .as_mut()
509            .ok_or_else(|| anyhow!("project is not shared"))?;
510        let (replica_id, _) = share
511            .guests
512            .remove(&connection_id)
513            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514        share.active_replica_ids.remove(&replica_id);
515
516        if let Some(connection) = self.connections.get_mut(&connection_id) {
517            connection.projects.remove(&project_id);
518        }
519
520        let connection_ids = project.connection_ids();
521        let authorized_user_ids = project.authorized_user_ids();
522
523        Ok(LeftProject {
524            connection_ids,
525            authorized_user_ids,
526        })
527    }
528
529    pub fn update_worktree(
530        &mut self,
531        connection_id: ConnectionId,
532        project_id: u64,
533        worktree_id: u64,
534        removed_entries: &[u64],
535        updated_entries: &[proto::Entry],
536    ) -> Result<Vec<ConnectionId>> {
537        let project = self.write_project(project_id, connection_id)?;
538        let worktree = project
539            .share_mut()?
540            .worktrees
541            .get_mut(&worktree_id)
542            .ok_or_else(|| anyhow!("no such worktree"))?;
543        for entry_id in removed_entries {
544            worktree.entries.remove(&entry_id);
545        }
546        for entry in updated_entries {
547            worktree.entries.insert(entry.id, entry.clone());
548        }
549        let connection_ids = project.connection_ids();
550        Ok(connection_ids)
551    }
552
553    pub fn project_connection_ids(
554        &self,
555        project_id: u64,
556        acting_connection_id: ConnectionId,
557    ) -> Result<Vec<ConnectionId>> {
558        Ok(self
559            .read_project(project_id, acting_connection_id)?
560            .connection_ids())
561    }
562
563    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
564        Ok(self
565            .channels
566            .get(&channel_id)
567            .ok_or_else(|| anyhow!("no such channel"))?
568            .connection_ids())
569    }
570
571    #[cfg(test)]
572    pub fn project(&self, project_id: u64) -> Option<&Project> {
573        self.projects.get(&project_id)
574    }
575
576    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
577        let project = self
578            .projects
579            .get(&project_id)
580            .ok_or_else(|| anyhow!("no such project"))?;
581        if project.host_connection_id == connection_id
582            || project
583                .share
584                .as_ref()
585                .ok_or_else(|| anyhow!("project is not shared"))?
586                .guests
587                .contains_key(&connection_id)
588        {
589            Ok(project)
590        } else {
591            Err(anyhow!("no such project"))?
592        }
593    }
594
595    fn write_project(
596        &mut self,
597        project_id: u64,
598        connection_id: ConnectionId,
599    ) -> Result<&mut Project> {
600        let project = self
601            .projects
602            .get_mut(&project_id)
603            .ok_or_else(|| anyhow!("no such project"))?;
604        if project.host_connection_id == connection_id
605            || project
606                .share
607                .as_ref()
608                .ok_or_else(|| anyhow!("project is not shared"))?
609                .guests
610                .contains_key(&connection_id)
611        {
612            Ok(project)
613        } else {
614            Err(anyhow!("no such project"))?
615        }
616    }
617
618    #[cfg(test)]
619    pub fn check_invariants(&self) {
620        for (connection_id, connection) in &self.connections {
621            for project_id in &connection.projects {
622                let project = &self.projects.get(&project_id).unwrap();
623                if project.host_connection_id != *connection_id {
624                    assert!(project
625                        .share
626                        .as_ref()
627                        .unwrap()
628                        .guests
629                        .contains_key(connection_id));
630                }
631
632                if let Some(share) = project.share.as_ref() {
633                    for (worktree_id, worktree) in share.worktrees.iter() {
634                        let mut paths = HashMap::default();
635                        for entry in worktree.entries.values() {
636                            let prev_entry = paths.insert(&entry.path, entry);
637                            assert_eq!(
638                                prev_entry,
639                                None,
640                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
641                                worktree_id,
642                                prev_entry.unwrap(),
643                                entry
644                            );
645                        }
646                    }
647                }
648            }
649            for channel_id in &connection.channels {
650                let channel = self.channels.get(channel_id).unwrap();
651                assert!(channel.connection_ids.contains(connection_id));
652            }
653            assert!(self
654                .connections_by_user_id
655                .get(&connection.user_id)
656                .unwrap()
657                .contains(connection_id));
658        }
659
660        for (user_id, connection_ids) in &self.connections_by_user_id {
661            for connection_id in connection_ids {
662                assert_eq!(
663                    self.connections.get(connection_id).unwrap().user_id,
664                    *user_id
665                );
666            }
667        }
668
669        for (project_id, project) in &self.projects {
670            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
671            assert!(host_connection.projects.contains(project_id));
672
673            for authorized_user_ids in project.authorized_user_ids() {
674                let visible_project_ids = self
675                    .visible_projects_by_user_id
676                    .get(&authorized_user_ids)
677                    .unwrap();
678                assert!(visible_project_ids.contains(project_id));
679            }
680
681            if let Some(share) = &project.share {
682                for guest_connection_id in share.guests.keys() {
683                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
684                    assert!(guest_connection.projects.contains(project_id));
685                }
686                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
687                assert_eq!(
688                    share.active_replica_ids,
689                    share
690                        .guests
691                        .values()
692                        .map(|(replica_id, _)| *replica_id)
693                        .collect::<HashSet<_>>(),
694                );
695            }
696        }
697
698        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
699            for project_id in visible_project_ids {
700                let project = self.projects.get(project_id).unwrap();
701                assert!(project.authorized_user_ids().contains(user_id));
702            }
703        }
704
705        for (channel_id, channel) in &self.channels {
706            for connection_id in &channel.connection_ids {
707                let connection = self.connections.get(connection_id).unwrap();
708                assert!(connection.channels.contains(channel_id));
709            }
710        }
711    }
712}
713
714impl Project {
715    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
716        self.worktrees
717            .values()
718            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
719    }
720
721    pub fn authorized_user_ids(&self) -> Vec<UserId> {
722        let mut ids = self
723            .worktrees
724            .values()
725            .flat_map(|worktree| worktree.authorized_user_ids.iter())
726            .copied()
727            .collect::<Vec<_>>();
728        ids.sort_unstable();
729        ids.dedup();
730        ids
731    }
732
733    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
734        if let Some(share) = &self.share {
735            share.guests.keys().copied().collect()
736        } else {
737            Vec::new()
738        }
739    }
740
741    pub fn connection_ids(&self) -> Vec<ConnectionId> {
742        if let Some(share) = &self.share {
743            share
744                .guests
745                .keys()
746                .copied()
747                .chain(Some(self.host_connection_id))
748                .collect()
749        } else {
750            vec![self.host_connection_id]
751        }
752    }
753
754    pub fn share(&self) -> Result<&ProjectShare> {
755        Ok(self
756            .share
757            .as_ref()
758            .ok_or_else(|| anyhow!("worktree is not shared"))?)
759    }
760
761    fn share_mut(&mut self) -> Result<&mut ProjectShare> {
762        Ok(self
763            .share
764            .as_mut()
765            .ok_or_else(|| anyhow!("worktree is not shared"))?)
766    }
767}
768
769impl Channel {
770    fn connection_ids(&self) -> Vec<ConnectionId> {
771        self.connection_ids.iter().copied().collect()
772    }
773}