// store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
/// In-memory registry of connections, the projects they host or have joined,
/// and channel membership. The maps below cross-reference each other; the
/// test-only `check_invariants` asserts that they stay consistent.
  8pub struct Store {
    /// Every registered connection, with the projects and channels it is in.
  9    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index from a user to their registered connections. A user's entry
    /// is removed by `remove_connection` once their last connection is gone.
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Registered projects, keyed by the id returned from `register_project`.
 11    projects: HashMap<u64, Project>,
    /// For each user, the projects that have at least one worktree listing the
    /// user in `authorized_user_ids`.
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channel membership keyed by channel id; an entry is created on first join.
 13    channels: HashMap<ChannelId, Channel>,
    /// Id that the next `register_project` call will hand out.
 14    next_project_id: u64,
 15}
 16
 /// Per-connection bookkeeping kept by `Store`.
 17struct ConnectionState {
    /// User that owns the connection.
 18    user_id: UserId,
    /// Projects the connection hosts or has joined as a guest.
 19    projects: HashSet<u64>,
    /// Channels the connection has joined.
 20    channels: HashSet<ChannelId>,
 21}
 22
 /// A project registered by a host connection.
 23pub struct Project {
    /// Connection that registered the project. Only this connection may register
    /// or unregister worktrees, share or unshare the project, and report
    /// diagnostics and language servers.
 24    pub host_connection_id: ConnectionId,
    /// User id passed to `register_project` for the host.
 25    pub host_user_id: UserId,
    /// `Some` while the project is shared; holds guests and per-worktree state.
 26    pub share: Option<ProjectShare>,
    /// Worktrees registered by the host, keyed by worktree id.
 27    pub worktrees: HashMap<u64, Worktree>,
    /// Language servers reported by the host through `start_language_server`.
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 /// Host-provided metadata for one worktree of a project.
 31pub struct Worktree {
    /// Users allowed to see the project (and join it) through this worktree.
 32    pub authorized_user_ids: Vec<UserId>,
    /// Root name reported in contact metadata when the worktree is visible.
 33    pub root_name: String,
    /// Whether `root_name` is included in `contacts_for_user` results.
 34    pub visible: bool,
 35}
 36
 37#[derive(Default)]
/// State that exists only while a project is shared.
 38pub struct ProjectShare {
    /// Guest connections, each mapped to its replica id and user id.
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
    /// Replica ids currently held by guests; mirrors the ids stored in `guests`.
 40    pub active_replica_ids: HashSet<ReplicaId>,
    /// Shared state for each worktree, keyed by worktree id.
 41    pub worktrees: HashMap<u64, WorktreeShare>,
 42}
 43
 44#[derive(Default)]
/// Shared contents of a single worktree.
 45pub struct WorktreeShare {
    /// Entries keyed by entry id, maintained by `update_worktree`.
 46    pub entries: HashMap<u64, proto::Entry>,
    /// Latest diagnostic summary per path, written by `update_diagnostic_summary`.
 47    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 48}
 49
 50#[derive(Default)]
/// Membership record of one channel.
 51pub struct Channel {
    /// Connections currently in the channel.
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 /// Identifier a guest uses within a shared project; `join_project` hands out
 /// the lowest value not currently in use, starting at 1.
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
/// Summary of what `remove_connection` tore down.
 58pub struct RemovedConnectionState {
    /// Projects the connection hosted; they have been unregistered.
 59    pub hosted_projects: HashMap<u64, Project>,
    /// For each project the connection had joined as a guest, the connections
    /// (remaining guests plus the host) still in that project.
 60    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Authorized users of every affected project.
 61    pub contact_ids: HashSet<UserId>,
 62}
 63
 /// Returned by `join_project`.
 64pub struct JoinedProject<'a> {
    /// Replica id assigned to the joining guest.
 65    pub replica_id: ReplicaId,
    /// The project that was joined, borrowed from the store.
 66    pub project: &'a Project,
 67}
 68
 /// Returned by `unshare_project`.
 69pub struct UnsharedProject {
    /// Connections that were in the project when it was unshared, host included.
 70    pub connection_ids: Vec<ConnectionId>,
    /// Authorized users of the project.
 71    pub authorized_user_ids: Vec<UserId>,
 72}
 73
 /// Returned by `leave_project`.
 74pub struct LeftProject {
    /// Connections still in the project after the guest left, host included.
 75    pub connection_ids: Vec<ConnectionId>,
    /// Authorized users of the project.
 76    pub authorized_user_ids: Vec<UserId>,
 77}
 78
 79impl Store {
 80    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 81        self.connections.insert(
 82            connection_id,
 83            ConnectionState {
 84                user_id,
 85                projects: Default::default(),
 86                channels: Default::default(),
 87            },
 88        );
 89        self.connections_by_user_id
 90            .entry(user_id)
 91            .or_default()
 92            .insert(connection_id);
 93    }
 94
 95    pub fn remove_connection(
 96        &mut self,
 97        connection_id: ConnectionId,
 98    ) -> tide::Result<RemovedConnectionState> {
 99        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
100            connection
101        } else {
102            return Err(anyhow!("no such connection"))?;
103        };
104
105        for channel_id in &connection.channels {
106            if let Some(channel) = self.channels.get_mut(&channel_id) {
107                channel.connection_ids.remove(&connection_id);
108            }
109        }
110
111        let user_connections = self
112            .connections_by_user_id
113            .get_mut(&connection.user_id)
114            .unwrap();
115        user_connections.remove(&connection_id);
116        if user_connections.is_empty() {
117            self.connections_by_user_id.remove(&connection.user_id);
118        }
119
120        let mut result = RemovedConnectionState::default();
121        for project_id in connection.projects.clone() {
122            if let Ok(project) = self.unregister_project(project_id, connection_id) {
123                result.contact_ids.extend(project.authorized_user_ids());
124                result.hosted_projects.insert(project_id, project);
125            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
126                result
127                    .guest_project_ids
128                    .insert(project_id, project.connection_ids);
129                result.contact_ids.extend(project.authorized_user_ids);
130            }
131        }
132
133        #[cfg(test)]
134        self.check_invariants();
135
136        Ok(result)
137    }
138
139    #[cfg(test)]
140    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
141        self.channels.get(&id)
142    }
143
144    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145        if let Some(connection) = self.connections.get_mut(&connection_id) {
146            connection.channels.insert(channel_id);
147            self.channels
148                .entry(channel_id)
149                .or_default()
150                .connection_ids
151                .insert(connection_id);
152        }
153    }
154
155    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
156        if let Some(connection) = self.connections.get_mut(&connection_id) {
157            connection.channels.remove(&channel_id);
158            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
159                entry.get_mut().connection_ids.remove(&connection_id);
160                if entry.get_mut().connection_ids.is_empty() {
161                    entry.remove();
162                }
163            }
164        }
165    }
166
167    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
168        Ok(self
169            .connections
170            .get(&connection_id)
171            .ok_or_else(|| anyhow!("unknown connection"))?
172            .user_id)
173    }
174
175    pub fn connection_ids_for_user<'a>(
176        &'a self,
177        user_id: UserId,
178    ) -> impl 'a + Iterator<Item = ConnectionId> {
179        self.connections_by_user_id
180            .get(&user_id)
181            .into_iter()
182            .flatten()
183            .copied()
184    }
185
186    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
187        let mut contacts = HashMap::default();
188        for project_id in self
189            .visible_projects_by_user_id
190            .get(&user_id)
191            .unwrap_or(&HashSet::default())
192        {
193            let project = &self.projects[project_id];
194
195            let mut guests = HashSet::default();
196            if let Ok(share) = project.share() {
197                for guest_connection_id in share.guests.keys() {
198                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
199                        guests.insert(user_id.to_proto());
200                    }
201                }
202            }
203
204            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
205                let mut worktree_root_names = project
206                    .worktrees
207                    .values()
208                    .filter(|worktree| worktree.visible)
209                    .map(|worktree| worktree.root_name.clone())
210                    .collect::<Vec<_>>();
211                worktree_root_names.sort_unstable();
212                contacts
213                    .entry(host_user_id)
214                    .or_insert_with(|| proto::Contact {
215                        user_id: host_user_id.to_proto(),
216                        projects: Vec::new(),
217                    })
218                    .projects
219                    .push(proto::ProjectMetadata {
220                        id: *project_id,
221                        worktree_root_names,
222                        is_shared: project.share.is_some(),
223                        guests: guests.into_iter().collect(),
224                    });
225            }
226        }
227
228        contacts.into_values().collect()
229    }
230
231    pub fn register_project(
232        &mut self,
233        host_connection_id: ConnectionId,
234        host_user_id: UserId,
235    ) -> u64 {
236        let project_id = self.next_project_id;
237        self.projects.insert(
238            project_id,
239            Project {
240                host_connection_id,
241                host_user_id,
242                share: None,
243                worktrees: Default::default(),
244                language_servers: Default::default(),
245            },
246        );
247        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
248            connection.projects.insert(project_id);
249        }
250        self.next_project_id += 1;
251        project_id
252    }
253
254    pub fn register_worktree(
255        &mut self,
256        project_id: u64,
257        worktree_id: u64,
258        connection_id: ConnectionId,
259        worktree: Worktree,
260    ) -> tide::Result<()> {
261        let project = self
262            .projects
263            .get_mut(&project_id)
264            .ok_or_else(|| anyhow!("no such project"))?;
265        if project.host_connection_id == connection_id {
266            for authorized_user_id in &worktree.authorized_user_ids {
267                self.visible_projects_by_user_id
268                    .entry(*authorized_user_id)
269                    .or_default()
270                    .insert(project_id);
271            }
272
273            project.worktrees.insert(worktree_id, worktree);
274            if let Ok(share) = project.share_mut() {
275                share.worktrees.insert(worktree_id, Default::default());
276            }
277
278            #[cfg(test)]
279            self.check_invariants();
280            Ok(())
281        } else {
282            Err(anyhow!("no such project"))?
283        }
284    }
285
286    pub fn unregister_project(
287        &mut self,
288        project_id: u64,
289        connection_id: ConnectionId,
290    ) -> tide::Result<Project> {
291        match self.projects.entry(project_id) {
292            hash_map::Entry::Occupied(e) => {
293                if e.get().host_connection_id == connection_id {
294                    for user_id in e.get().authorized_user_ids() {
295                        if let hash_map::Entry::Occupied(mut projects) =
296                            self.visible_projects_by_user_id.entry(user_id)
297                        {
298                            projects.get_mut().remove(&project_id);
299                        }
300                    }
301
302                    let project = e.remove();
303
304                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
305                        host_connection.projects.remove(&project_id);
306                    }
307
308                    if let Some(share) = &project.share {
309                        for guest_connection in share.guests.keys() {
310                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
311                                connection.projects.remove(&project_id);
312                            }
313                        }
314                    }
315
316                    #[cfg(test)]
317                    self.check_invariants();
318                    Ok(project)
319                } else {
320                    Err(anyhow!("no such project"))?
321                }
322            }
323            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
324        }
325    }
326
327    pub fn unregister_worktree(
328        &mut self,
329        project_id: u64,
330        worktree_id: u64,
331        acting_connection_id: ConnectionId,
332    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
333        let project = self
334            .projects
335            .get_mut(&project_id)
336            .ok_or_else(|| anyhow!("no such project"))?;
337        if project.host_connection_id != acting_connection_id {
338            Err(anyhow!("not your worktree"))?;
339        }
340
341        let worktree = project
342            .worktrees
343            .remove(&worktree_id)
344            .ok_or_else(|| anyhow!("no such worktree"))?;
345
346        let mut guest_connection_ids = Vec::new();
347        if let Ok(share) = project.share_mut() {
348            guest_connection_ids.extend(share.guests.keys());
349            share.worktrees.remove(&worktree_id);
350        }
351
352        for authorized_user_id in &worktree.authorized_user_ids {
353            if let Some(visible_projects) =
354                self.visible_projects_by_user_id.get_mut(authorized_user_id)
355            {
356                if !project.has_authorized_user_id(*authorized_user_id) {
357                    visible_projects.remove(&project_id);
358                }
359            }
360        }
361
362        #[cfg(test)]
363        self.check_invariants();
364
365        Ok((worktree, guest_connection_ids))
366    }
367
368    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
369        if let Some(project) = self.projects.get_mut(&project_id) {
370            if project.host_connection_id == connection_id {
371                let mut share = ProjectShare::default();
372                for worktree_id in project.worktrees.keys() {
373                    share.worktrees.insert(*worktree_id, Default::default());
374                }
375                project.share = Some(share);
376                return true;
377            }
378        }
379        false
380    }
381
382    pub fn unshare_project(
383        &mut self,
384        project_id: u64,
385        acting_connection_id: ConnectionId,
386    ) -> tide::Result<UnsharedProject> {
387        let project = if let Some(project) = self.projects.get_mut(&project_id) {
388            project
389        } else {
390            return Err(anyhow!("no such project"))?;
391        };
392
393        if project.host_connection_id != acting_connection_id {
394            return Err(anyhow!("not your project"))?;
395        }
396
397        let connection_ids = project.connection_ids();
398        let authorized_user_ids = project.authorized_user_ids();
399        if let Some(share) = project.share.take() {
400            for connection_id in share.guests.into_keys() {
401                if let Some(connection) = self.connections.get_mut(&connection_id) {
402                    connection.projects.remove(&project_id);
403                }
404            }
405
406            #[cfg(test)]
407            self.check_invariants();
408
409            Ok(UnsharedProject {
410                connection_ids,
411                authorized_user_ids,
412            })
413        } else {
414            Err(anyhow!("project is not shared"))?
415        }
416    }
417
418    pub fn update_diagnostic_summary(
419        &mut self,
420        project_id: u64,
421        worktree_id: u64,
422        connection_id: ConnectionId,
423        summary: proto::DiagnosticSummary,
424    ) -> tide::Result<Vec<ConnectionId>> {
425        let project = self
426            .projects
427            .get_mut(&project_id)
428            .ok_or_else(|| anyhow!("no such project"))?;
429        if project.host_connection_id == connection_id {
430            let worktree = project
431                .share_mut()?
432                .worktrees
433                .get_mut(&worktree_id)
434                .ok_or_else(|| anyhow!("no such worktree"))?;
435            worktree
436                .diagnostic_summaries
437                .insert(summary.path.clone().into(), summary);
438            return Ok(project.connection_ids());
439        }
440
441        Err(anyhow!("no such worktree"))?
442    }
443
444    pub fn start_language_server(
445        &mut self,
446        project_id: u64,
447        connection_id: ConnectionId,
448        language_server: proto::LanguageServer,
449    ) -> tide::Result<Vec<ConnectionId>> {
450        let project = self
451            .projects
452            .get_mut(&project_id)
453            .ok_or_else(|| anyhow!("no such project"))?;
454        if project.host_connection_id == connection_id {
455            project.language_servers.push(language_server);
456            return Ok(project.connection_ids());
457        }
458
459        Err(anyhow!("no such project"))?
460    }
461
462    pub fn join_project(
463        &mut self,
464        connection_id: ConnectionId,
465        user_id: UserId,
466        project_id: u64,
467    ) -> tide::Result<JoinedProject> {
468        let connection = self
469            .connections
470            .get_mut(&connection_id)
471            .ok_or_else(|| anyhow!("no such connection"))?;
472        let project = self
473            .projects
474            .get_mut(&project_id)
475            .and_then(|project| {
476                if project.has_authorized_user_id(user_id) {
477                    Some(project)
478                } else {
479                    None
480                }
481            })
482            .ok_or_else(|| anyhow!("no such project"))?;
483
484        let share = project.share_mut()?;
485        connection.projects.insert(project_id);
486
487        let mut replica_id = 1;
488        while share.active_replica_ids.contains(&replica_id) {
489            replica_id += 1;
490        }
491        share.active_replica_ids.insert(replica_id);
492        share.guests.insert(connection_id, (replica_id, user_id));
493
494        #[cfg(test)]
495        self.check_invariants();
496
497        Ok(JoinedProject {
498            replica_id,
499            project: &self.projects[&project_id],
500        })
501    }
502
503    pub fn leave_project(
504        &mut self,
505        connection_id: ConnectionId,
506        project_id: u64,
507    ) -> tide::Result<LeftProject> {
508        let project = self
509            .projects
510            .get_mut(&project_id)
511            .ok_or_else(|| anyhow!("no such project"))?;
512        let share = project
513            .share
514            .as_mut()
515            .ok_or_else(|| anyhow!("project is not shared"))?;
516        let (replica_id, _) = share
517            .guests
518            .remove(&connection_id)
519            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
520        share.active_replica_ids.remove(&replica_id);
521
522        if let Some(connection) = self.connections.get_mut(&connection_id) {
523            connection.projects.remove(&project_id);
524        }
525
526        let connection_ids = project.connection_ids();
527        let authorized_user_ids = project.authorized_user_ids();
528
529        #[cfg(test)]
530        self.check_invariants();
531
532        Ok(LeftProject {
533            connection_ids,
534            authorized_user_ids,
535        })
536    }
537
538    pub fn update_worktree(
539        &mut self,
540        connection_id: ConnectionId,
541        project_id: u64,
542        worktree_id: u64,
543        removed_entries: &[u64],
544        updated_entries: &[proto::Entry],
545    ) -> tide::Result<Vec<ConnectionId>> {
546        let project = self.write_project(project_id, connection_id)?;
547        let worktree = project
548            .share_mut()?
549            .worktrees
550            .get_mut(&worktree_id)
551            .ok_or_else(|| anyhow!("no such worktree"))?;
552        for entry_id in removed_entries {
553            worktree.entries.remove(&entry_id);
554        }
555        for entry in updated_entries {
556            worktree.entries.insert(entry.id, entry.clone());
557        }
558        let connection_ids = project.connection_ids();
559
560        #[cfg(test)]
561        self.check_invariants();
562
563        Ok(connection_ids)
564    }
565
566    pub fn project_connection_ids(
567        &self,
568        project_id: u64,
569        acting_connection_id: ConnectionId,
570    ) -> tide::Result<Vec<ConnectionId>> {
571        Ok(self
572            .read_project(project_id, acting_connection_id)?
573            .connection_ids())
574    }
575
576    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
577        Ok(self
578            .channels
579            .get(&channel_id)
580            .ok_or_else(|| anyhow!("no such channel"))?
581            .connection_ids())
582    }
583
584    #[cfg(test)]
585    pub fn project(&self, project_id: u64) -> Option<&Project> {
586        self.projects.get(&project_id)
587    }
588
589    pub fn read_project(
590        &self,
591        project_id: u64,
592        connection_id: ConnectionId,
593    ) -> tide::Result<&Project> {
594        let project = self
595            .projects
596            .get(&project_id)
597            .ok_or_else(|| anyhow!("no such project"))?;
598        if project.host_connection_id == connection_id
599            || project
600                .share
601                .as_ref()
602                .ok_or_else(|| anyhow!("project is not shared"))?
603                .guests
604                .contains_key(&connection_id)
605        {
606            Ok(project)
607        } else {
608            Err(anyhow!("no such project"))?
609        }
610    }
611
612    fn write_project(
613        &mut self,
614        project_id: u64,
615        connection_id: ConnectionId,
616    ) -> tide::Result<&mut Project> {
617        let project = self
618            .projects
619            .get_mut(&project_id)
620            .ok_or_else(|| anyhow!("no such project"))?;
621        if project.host_connection_id == connection_id
622            || project
623                .share
624                .as_ref()
625                .ok_or_else(|| anyhow!("project is not shared"))?
626                .guests
627                .contains_key(&connection_id)
628        {
629            Ok(project)
630        } else {
631            Err(anyhow!("no such project"))?
632        }
633    }
634
635    #[cfg(test)]
    /// Test-only consistency check across every index in the store. Panics (via
    /// `assert!`, `assert_eq!`, or `unwrap`) on the first inconsistency:
    ///
    /// - each connection's projects, channels, and user-index entry link back to it;
    /// - no shared worktree holds two entries with the same path;
    /// - each user-index entry lists only connections owned by that user;
    /// - each project is listed by its host and guests, is visible to its
    ///   authorized users, and its guests' replica ids equal `active_replica_ids`;
    /// - each visible-project entry is backed by a worktree authorization;
    /// - each channel member lists the channel.
636    fn check_invariants(&self) {
        // Connection -> project, channel, and user-index links.
637        for (connection_id, connection) in &self.connections {
638            for project_id in &connection.projects {
639                let project = &self.projects.get(&project_id).unwrap();
                // A non-host connection that lists a project must be one of its guests.
640                if project.host_connection_id != *connection_id {
641                    assert!(project
642                        .share
643                        .as_ref()
644                        .unwrap()
645                        .guests
646                        .contains_key(connection_id));
647                }
648
                // NOTE: this runs once per connection listing the project, so a
                // share is re-checked for each member; redundant but harmless.
649                if let Some(share) = project.share.as_ref() {
650                    for (worktree_id, worktree) in share.worktrees.iter() {
651                        let mut paths = HashMap::default();
652                        for entry in worktree.entries.values() {
653                            let prev_entry = paths.insert(&entry.path, entry);
                            // The message arguments are only evaluated when the
                            // assertion fails, i.e. when `prev_entry` is `Some`, so the
                            // `unwrap` inside them cannot panic.
654                            assert_eq!(
655                                prev_entry,
656                                None,
657                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
658                                worktree_id,
659                                prev_entry.unwrap(),
660                                entry
661                            );
662                        }
663                    }
664                }
665            }
666            for channel_id in &connection.channels {
667                let channel = self.channels.get(channel_id).unwrap();
668                assert!(channel.connection_ids.contains(connection_id));
669            }
670            assert!(self
671                .connections_by_user_id
672                .get(&connection.user_id)
673                .unwrap()
674                .contains(connection_id));
675        }
676
        // User index -> connections owned by that user.
677        for (user_id, connection_ids) in &self.connections_by_user_id {
678            for connection_id in connection_ids {
679                assert_eq!(
680                    self.connections.get(connection_id).unwrap().user_id,
681                    *user_id
682                );
683            }
684        }
685
        // Project -> host, authorized users, and guests.
686        for (project_id, project) in &self.projects {
687            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
688            assert!(host_connection.projects.contains(project_id));
689
690            for authorized_user_ids in project.authorized_user_ids() {
691                let visible_project_ids = self
692                    .visible_projects_by_user_id
693                    .get(&authorized_user_ids)
694                    .unwrap();
695                assert!(visible_project_ids.contains(project_id));
696            }
697
698            if let Some(share) = &project.share {
699                for guest_connection_id in share.guests.keys() {
700                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
701                    assert!(guest_connection.projects.contains(project_id));
702                }
                // Each guest holds a distinct replica id, and the active set is
                // exactly the guests' ids.
703                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
704                assert_eq!(
705                    share.active_replica_ids,
706                    share
707                        .guests
708                        .values()
709                        .map(|(replica_id, _)| *replica_id)
710                        .collect::<HashSet<_>>(),
711                );
712            }
713        }
714
        // Visible-project index -> a worktree that authorizes the user.
715        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
716            for project_id in visible_project_ids {
717                let project = self.projects.get(project_id).unwrap();
718                assert!(project.authorized_user_ids().contains(user_id));
719            }
720        }
721
        // Channel membership -> each member connection lists the channel.
722        for (channel_id, channel) in &self.channels {
723            for connection_id in &channel.connection_ids {
724                let connection = self.connections.get(connection_id).unwrap();
725                assert!(connection.channels.contains(channel_id));
726            }
727        }
728    }
729}
730
731impl Project {
732    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
733        self.worktrees
734            .values()
735            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
736    }
737
738    pub fn authorized_user_ids(&self) -> Vec<UserId> {
739        let mut ids = self
740            .worktrees
741            .values()
742            .flat_map(|worktree| worktree.authorized_user_ids.iter())
743            .copied()
744            .collect::<Vec<_>>();
745        ids.sort_unstable();
746        ids.dedup();
747        ids
748    }
749
750    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
751        if let Some(share) = &self.share {
752            share.guests.keys().copied().collect()
753        } else {
754            Vec::new()
755        }
756    }
757
758    pub fn connection_ids(&self) -> Vec<ConnectionId> {
759        if let Some(share) = &self.share {
760            share
761                .guests
762                .keys()
763                .copied()
764                .chain(Some(self.host_connection_id))
765                .collect()
766        } else {
767            vec![self.host_connection_id]
768        }
769    }
770
771    pub fn share(&self) -> tide::Result<&ProjectShare> {
772        Ok(self
773            .share
774            .as_ref()
775            .ok_or_else(|| anyhow!("worktree is not shared"))?)
776    }
777
778    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
779        Ok(self
780            .share
781            .as_mut()
782            .ok_or_else(|| anyhow!("worktree is not shared"))?)
783    }
784}
785
786impl Channel {
787    fn connection_ids(&self) -> Vec<ConnectionId> {
788        self.connection_ids.iter().copied().collect()
789    }
790}