store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub authorized_user_ids: Vec<UserId>,
 33    pub root_name: String,
 34    pub visible: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41    pub worktrees: HashMap<u64, WorktreeShare>,
 42}
 43
 44#[derive(Default)]
 45pub struct WorktreeShare {
 46    pub entries: HashMap<u64, proto::Entry>,
 47    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub hosted_projects: HashMap<u64, Project>,
 60    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 61    pub contact_ids: HashSet<UserId>,
 62}
 63
 64pub struct JoinedProject<'a> {
 65    pub replica_id: ReplicaId,
 66    pub project: &'a Project,
 67}
 68
 69pub struct SharedProject {
 70    pub authorized_user_ids: Vec<UserId>,
 71}
 72
 73pub struct UnsharedProject {
 74    pub connection_ids: Vec<ConnectionId>,
 75    pub authorized_user_ids: Vec<UserId>,
 76}
 77
 78pub struct LeftProject {
 79    pub connection_ids: Vec<ConnectionId>,
 80    pub authorized_user_ids: Vec<UserId>,
 81}
 82
 83impl Store {
 84    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 85        self.connections.insert(
 86            connection_id,
 87            ConnectionState {
 88                user_id,
 89                projects: Default::default(),
 90                channels: Default::default(),
 91            },
 92        );
 93        self.connections_by_user_id
 94            .entry(user_id)
 95            .or_default()
 96            .insert(connection_id);
 97    }
 98
 99    pub fn remove_connection(
100        &mut self,
101        connection_id: ConnectionId,
102    ) -> tide::Result<RemovedConnectionState> {
103        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
104            connection
105        } else {
106            return Err(anyhow!("no such connection"))?;
107        };
108
109        for channel_id in &connection.channels {
110            if let Some(channel) = self.channels.get_mut(&channel_id) {
111                channel.connection_ids.remove(&connection_id);
112            }
113        }
114
115        let user_connections = self
116            .connections_by_user_id
117            .get_mut(&connection.user_id)
118            .unwrap();
119        user_connections.remove(&connection_id);
120        if user_connections.is_empty() {
121            self.connections_by_user_id.remove(&connection.user_id);
122        }
123
124        let mut result = RemovedConnectionState::default();
125        for project_id in connection.projects.clone() {
126            if let Ok(project) = self.unregister_project(project_id, connection_id) {
127                result.contact_ids.extend(project.authorized_user_ids());
128                result.hosted_projects.insert(project_id, project);
129            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
130                result
131                    .guest_project_ids
132                    .insert(project_id, project.connection_ids);
133                result.contact_ids.extend(project.authorized_user_ids);
134            }
135        }
136
137        Ok(result)
138    }
139
140    #[cfg(test)]
141    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
142        self.channels.get(&id)
143    }
144
145    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
146        if let Some(connection) = self.connections.get_mut(&connection_id) {
147            connection.channels.insert(channel_id);
148            self.channels
149                .entry(channel_id)
150                .or_default()
151                .connection_ids
152                .insert(connection_id);
153        }
154    }
155
156    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
157        if let Some(connection) = self.connections.get_mut(&connection_id) {
158            connection.channels.remove(&channel_id);
159            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
160                entry.get_mut().connection_ids.remove(&connection_id);
161                if entry.get_mut().connection_ids.is_empty() {
162                    entry.remove();
163                }
164            }
165        }
166    }
167
168    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
169        Ok(self
170            .connections
171            .get(&connection_id)
172            .ok_or_else(|| anyhow!("unknown connection"))?
173            .user_id)
174    }
175
176    pub fn connection_ids_for_user<'a>(
177        &'a self,
178        user_id: UserId,
179    ) -> impl 'a + Iterator<Item = ConnectionId> {
180        self.connections_by_user_id
181            .get(&user_id)
182            .into_iter()
183            .flatten()
184            .copied()
185    }
186
187    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
188        let mut contacts = HashMap::default();
189        for project_id in self
190            .visible_projects_by_user_id
191            .get(&user_id)
192            .unwrap_or(&HashSet::default())
193        {
194            let project = &self.projects[project_id];
195
196            let mut guests = HashSet::default();
197            if let Ok(share) = project.share() {
198                for guest_connection_id in share.guests.keys() {
199                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
200                        guests.insert(user_id.to_proto());
201                    }
202                }
203            }
204
205            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
206                let mut worktree_root_names = project
207                    .worktrees
208                    .values()
209                    .filter(|worktree| worktree.visible)
210                    .map(|worktree| worktree.root_name.clone())
211                    .collect::<Vec<_>>();
212                worktree_root_names.sort_unstable();
213                contacts
214                    .entry(host_user_id)
215                    .or_insert_with(|| proto::Contact {
216                        user_id: host_user_id.to_proto(),
217                        projects: Vec::new(),
218                    })
219                    .projects
220                    .push(proto::ProjectMetadata {
221                        id: *project_id,
222                        worktree_root_names,
223                        is_shared: project.share.is_some(),
224                        guests: guests.into_iter().collect(),
225                    });
226            }
227        }
228
229        contacts.into_values().collect()
230    }
231
232    pub fn register_project(
233        &mut self,
234        host_connection_id: ConnectionId,
235        host_user_id: UserId,
236    ) -> u64 {
237        let project_id = self.next_project_id;
238        self.projects.insert(
239            project_id,
240            Project {
241                host_connection_id,
242                host_user_id,
243                share: None,
244                worktrees: Default::default(),
245                language_servers: Default::default(),
246            },
247        );
248        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
249            connection.projects.insert(project_id);
250        }
251        self.next_project_id += 1;
252        project_id
253    }
254
255    pub fn register_worktree(
256        &mut self,
257        project_id: u64,
258        worktree_id: u64,
259        connection_id: ConnectionId,
260        worktree: Worktree,
261    ) -> tide::Result<()> {
262        let project = self
263            .projects
264            .get_mut(&project_id)
265            .ok_or_else(|| anyhow!("no such project"))?;
266        if project.host_connection_id == connection_id {
267            for authorized_user_id in &worktree.authorized_user_ids {
268                self.visible_projects_by_user_id
269                    .entry(*authorized_user_id)
270                    .or_default()
271                    .insert(project_id);
272            }
273
274            project.worktrees.insert(worktree_id, worktree);
275            if let Ok(share) = project.share_mut() {
276                share.worktrees.insert(worktree_id, Default::default());
277            }
278
279            Ok(())
280        } else {
281            Err(anyhow!("no such project"))?
282        }
283    }
284
285    pub fn unregister_project(
286        &mut self,
287        project_id: u64,
288        connection_id: ConnectionId,
289    ) -> tide::Result<Project> {
290        match self.projects.entry(project_id) {
291            hash_map::Entry::Occupied(e) => {
292                if e.get().host_connection_id == connection_id {
293                    for user_id in e.get().authorized_user_ids() {
294                        if let hash_map::Entry::Occupied(mut projects) =
295                            self.visible_projects_by_user_id.entry(user_id)
296                        {
297                            projects.get_mut().remove(&project_id);
298                        }
299                    }
300
301                    let project = e.remove();
302
303                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
304                        host_connection.projects.remove(&project_id);
305                    }
306
307                    if let Some(share) = &project.share {
308                        for guest_connection in share.guests.keys() {
309                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
310                                connection.projects.remove(&project_id);
311                            }
312                        }
313                    }
314
315                    Ok(project)
316                } else {
317                    Err(anyhow!("no such project"))?
318                }
319            }
320            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
321        }
322    }
323
324    pub fn unregister_worktree(
325        &mut self,
326        project_id: u64,
327        worktree_id: u64,
328        acting_connection_id: ConnectionId,
329    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
330        let project = self
331            .projects
332            .get_mut(&project_id)
333            .ok_or_else(|| anyhow!("no such project"))?;
334        if project.host_connection_id != acting_connection_id {
335            Err(anyhow!("not your worktree"))?;
336        }
337
338        let worktree = project
339            .worktrees
340            .remove(&worktree_id)
341            .ok_or_else(|| anyhow!("no such worktree"))?;
342
343        let mut guest_connection_ids = Vec::new();
344        if let Ok(share) = project.share_mut() {
345            guest_connection_ids.extend(share.guests.keys());
346            share.worktrees.remove(&worktree_id);
347        }
348
349        for authorized_user_id in &worktree.authorized_user_ids {
350            if let Some(visible_projects) =
351                self.visible_projects_by_user_id.get_mut(authorized_user_id)
352            {
353                if !project.has_authorized_user_id(*authorized_user_id) {
354                    visible_projects.remove(&project_id);
355                }
356            }
357        }
358
359        Ok((worktree, guest_connection_ids))
360    }
361
362    pub fn share_project(
363        &mut self,
364        project_id: u64,
365        connection_id: ConnectionId,
366    ) -> tide::Result<SharedProject> {
367        if let Some(project) = self.projects.get_mut(&project_id) {
368            if project.host_connection_id == connection_id {
369                let mut share = ProjectShare::default();
370                for worktree_id in project.worktrees.keys() {
371                    share.worktrees.insert(*worktree_id, Default::default());
372                }
373                project.share = Some(share);
374                return Ok(SharedProject {
375                    authorized_user_ids: project.authorized_user_ids(),
376                });
377            }
378        }
379        Err(anyhow!("no such project"))?
380    }
381
382    pub fn unshare_project(
383        &mut self,
384        project_id: u64,
385        acting_connection_id: ConnectionId,
386    ) -> tide::Result<UnsharedProject> {
387        let project = if let Some(project) = self.projects.get_mut(&project_id) {
388            project
389        } else {
390            return Err(anyhow!("no such project"))?;
391        };
392
393        if project.host_connection_id != acting_connection_id {
394            return Err(anyhow!("not your project"))?;
395        }
396
397        let connection_ids = project.connection_ids();
398        let authorized_user_ids = project.authorized_user_ids();
399        if let Some(share) = project.share.take() {
400            for connection_id in share.guests.into_keys() {
401                if let Some(connection) = self.connections.get_mut(&connection_id) {
402                    connection.projects.remove(&project_id);
403                }
404            }
405
406            Ok(UnsharedProject {
407                connection_ids,
408                authorized_user_ids,
409            })
410        } else {
411            Err(anyhow!("project is not shared"))?
412        }
413    }
414
415    pub fn update_diagnostic_summary(
416        &mut self,
417        project_id: u64,
418        worktree_id: u64,
419        connection_id: ConnectionId,
420        summary: proto::DiagnosticSummary,
421    ) -> tide::Result<Vec<ConnectionId>> {
422        let project = self
423            .projects
424            .get_mut(&project_id)
425            .ok_or_else(|| anyhow!("no such project"))?;
426        if project.host_connection_id == connection_id {
427            let worktree = project
428                .share_mut()?
429                .worktrees
430                .get_mut(&worktree_id)
431                .ok_or_else(|| anyhow!("no such worktree"))?;
432            worktree
433                .diagnostic_summaries
434                .insert(summary.path.clone().into(), summary);
435            return Ok(project.connection_ids());
436        }
437
438        Err(anyhow!("no such worktree"))?
439    }
440
441    pub fn start_language_server(
442        &mut self,
443        project_id: u64,
444        connection_id: ConnectionId,
445        language_server: proto::LanguageServer,
446    ) -> tide::Result<Vec<ConnectionId>> {
447        let project = self
448            .projects
449            .get_mut(&project_id)
450            .ok_or_else(|| anyhow!("no such project"))?;
451        if project.host_connection_id == connection_id {
452            project.language_servers.push(language_server);
453            return Ok(project.connection_ids());
454        }
455
456        Err(anyhow!("no such project"))?
457    }
458
459    pub fn join_project(
460        &mut self,
461        connection_id: ConnectionId,
462        user_id: UserId,
463        project_id: u64,
464    ) -> tide::Result<JoinedProject> {
465        let connection = self
466            .connections
467            .get_mut(&connection_id)
468            .ok_or_else(|| anyhow!("no such connection"))?;
469        let project = self
470            .projects
471            .get_mut(&project_id)
472            .and_then(|project| {
473                if project.has_authorized_user_id(user_id) {
474                    Some(project)
475                } else {
476                    None
477                }
478            })
479            .ok_or_else(|| anyhow!("no such project"))?;
480
481        let share = project.share_mut()?;
482        connection.projects.insert(project_id);
483
484        let mut replica_id = 1;
485        while share.active_replica_ids.contains(&replica_id) {
486            replica_id += 1;
487        }
488        share.active_replica_ids.insert(replica_id);
489        share.guests.insert(connection_id, (replica_id, user_id));
490
491        Ok(JoinedProject {
492            replica_id,
493            project: &self.projects[&project_id],
494        })
495    }
496
497    pub fn leave_project(
498        &mut self,
499        connection_id: ConnectionId,
500        project_id: u64,
501    ) -> tide::Result<LeftProject> {
502        let project = self
503            .projects
504            .get_mut(&project_id)
505            .ok_or_else(|| anyhow!("no such project"))?;
506        let share = project
507            .share
508            .as_mut()
509            .ok_or_else(|| anyhow!("project is not shared"))?;
510        let (replica_id, _) = share
511            .guests
512            .remove(&connection_id)
513            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514        share.active_replica_ids.remove(&replica_id);
515
516        if let Some(connection) = self.connections.get_mut(&connection_id) {
517            connection.projects.remove(&project_id);
518        }
519
520        let connection_ids = project.connection_ids();
521        let authorized_user_ids = project.authorized_user_ids();
522
523        Ok(LeftProject {
524            connection_ids,
525            authorized_user_ids,
526        })
527    }
528
529    pub fn update_worktree(
530        &mut self,
531        connection_id: ConnectionId,
532        project_id: u64,
533        worktree_id: u64,
534        removed_entries: &[u64],
535        updated_entries: &[proto::Entry],
536    ) -> tide::Result<Vec<ConnectionId>> {
537        let project = self.write_project(project_id, connection_id)?;
538        let worktree = project
539            .share_mut()?
540            .worktrees
541            .get_mut(&worktree_id)
542            .ok_or_else(|| anyhow!("no such worktree"))?;
543        for entry_id in removed_entries {
544            worktree.entries.remove(&entry_id);
545        }
546        for entry in updated_entries {
547            worktree.entries.insert(entry.id, entry.clone());
548        }
549        let connection_ids = project.connection_ids();
550        Ok(connection_ids)
551    }
552
553    pub fn project_connection_ids(
554        &self,
555        project_id: u64,
556        acting_connection_id: ConnectionId,
557    ) -> tide::Result<Vec<ConnectionId>> {
558        Ok(self
559            .read_project(project_id, acting_connection_id)?
560            .connection_ids())
561    }
562
563    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
564        Ok(self
565            .channels
566            .get(&channel_id)
567            .ok_or_else(|| anyhow!("no such channel"))?
568            .connection_ids())
569    }
570
571    #[cfg(test)]
572    pub fn project(&self, project_id: u64) -> Option<&Project> {
573        self.projects.get(&project_id)
574    }
575
576    pub fn read_project(
577        &self,
578        project_id: u64,
579        connection_id: ConnectionId,
580    ) -> tide::Result<&Project> {
581        let project = self
582            .projects
583            .get(&project_id)
584            .ok_or_else(|| anyhow!("no such project"))?;
585        if project.host_connection_id == connection_id
586            || project
587                .share
588                .as_ref()
589                .ok_or_else(|| anyhow!("project is not shared"))?
590                .guests
591                .contains_key(&connection_id)
592        {
593            Ok(project)
594        } else {
595            Err(anyhow!("no such project"))?
596        }
597    }
598
599    fn write_project(
600        &mut self,
601        project_id: u64,
602        connection_id: ConnectionId,
603    ) -> tide::Result<&mut Project> {
604        let project = self
605            .projects
606            .get_mut(&project_id)
607            .ok_or_else(|| anyhow!("no such project"))?;
608        if project.host_connection_id == connection_id
609            || project
610                .share
611                .as_ref()
612                .ok_or_else(|| anyhow!("project is not shared"))?
613                .guests
614                .contains_key(&connection_id)
615        {
616            Ok(project)
617        } else {
618            Err(anyhow!("no such project"))?
619        }
620    }
621
622    #[cfg(test)]
623    pub fn check_invariants(&self) {
624        for (connection_id, connection) in &self.connections {
625            for project_id in &connection.projects {
626                let project = &self.projects.get(&project_id).unwrap();
627                if project.host_connection_id != *connection_id {
628                    assert!(project
629                        .share
630                        .as_ref()
631                        .unwrap()
632                        .guests
633                        .contains_key(connection_id));
634                }
635
636                if let Some(share) = project.share.as_ref() {
637                    for (worktree_id, worktree) in share.worktrees.iter() {
638                        let mut paths = HashMap::default();
639                        for entry in worktree.entries.values() {
640                            let prev_entry = paths.insert(&entry.path, entry);
641                            assert_eq!(
642                                prev_entry,
643                                None,
644                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
645                                worktree_id,
646                                prev_entry.unwrap(),
647                                entry
648                            );
649                        }
650                    }
651                }
652            }
653            for channel_id in &connection.channels {
654                let channel = self.channels.get(channel_id).unwrap();
655                assert!(channel.connection_ids.contains(connection_id));
656            }
657            assert!(self
658                .connections_by_user_id
659                .get(&connection.user_id)
660                .unwrap()
661                .contains(connection_id));
662        }
663
664        for (user_id, connection_ids) in &self.connections_by_user_id {
665            for connection_id in connection_ids {
666                assert_eq!(
667                    self.connections.get(connection_id).unwrap().user_id,
668                    *user_id
669                );
670            }
671        }
672
673        for (project_id, project) in &self.projects {
674            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
675            assert!(host_connection.projects.contains(project_id));
676
677            for authorized_user_ids in project.authorized_user_ids() {
678                let visible_project_ids = self
679                    .visible_projects_by_user_id
680                    .get(&authorized_user_ids)
681                    .unwrap();
682                assert!(visible_project_ids.contains(project_id));
683            }
684
685            if let Some(share) = &project.share {
686                for guest_connection_id in share.guests.keys() {
687                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
688                    assert!(guest_connection.projects.contains(project_id));
689                }
690                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
691                assert_eq!(
692                    share.active_replica_ids,
693                    share
694                        .guests
695                        .values()
696                        .map(|(replica_id, _)| *replica_id)
697                        .collect::<HashSet<_>>(),
698                );
699            }
700        }
701
702        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
703            for project_id in visible_project_ids {
704                let project = self.projects.get(project_id).unwrap();
705                assert!(project.authorized_user_ids().contains(user_id));
706            }
707        }
708
709        for (channel_id, channel) in &self.channels {
710            for connection_id in &channel.connection_ids {
711                let connection = self.connections.get(connection_id).unwrap();
712                assert!(connection.channels.contains(channel_id));
713            }
714        }
715    }
716}
717
718impl Project {
719    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
720        self.worktrees
721            .values()
722            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
723    }
724
725    pub fn authorized_user_ids(&self) -> Vec<UserId> {
726        let mut ids = self
727            .worktrees
728            .values()
729            .flat_map(|worktree| worktree.authorized_user_ids.iter())
730            .copied()
731            .collect::<Vec<_>>();
732        ids.sort_unstable();
733        ids.dedup();
734        ids
735    }
736
737    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
738        if let Some(share) = &self.share {
739            share.guests.keys().copied().collect()
740        } else {
741            Vec::new()
742        }
743    }
744
745    pub fn connection_ids(&self) -> Vec<ConnectionId> {
746        if let Some(share) = &self.share {
747            share
748                .guests
749                .keys()
750                .copied()
751                .chain(Some(self.host_connection_id))
752                .collect()
753        } else {
754            vec![self.host_connection_id]
755        }
756    }
757
758    pub fn share(&self) -> tide::Result<&ProjectShare> {
759        Ok(self
760            .share
761            .as_ref()
762            .ok_or_else(|| anyhow!("worktree is not shared"))?)
763    }
764
765    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
766        Ok(self
767            .share
768            .as_mut()
769            .ok_or_else(|| anyhow!("worktree is not shared"))?)
770    }
771}
772
773impl Channel {
774    fn connection_ids(&self) -> Vec<ConnectionId> {
775        self.connection_ids.iter().copied().collect()
776    }
777}