store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub authorized_user_ids: Vec<UserId>,
 33    pub root_name: String,
 34    pub visible: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41    pub worktrees: HashMap<u64, WorktreeShare>,
 42}
 43
 44#[derive(Default)]
 45pub struct WorktreeShare {
 46    pub entries: HashMap<u64, proto::Entry>,
 47    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub hosted_projects: HashMap<u64, Project>,
 60    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 61    pub contact_ids: HashSet<UserId>,
 62}
 63
 64pub struct JoinedProject<'a> {
 65    pub replica_id: ReplicaId,
 66    pub project: &'a Project,
 67}
 68
 69pub struct UnsharedProject {
 70    pub connection_ids: Vec<ConnectionId>,
 71    pub authorized_user_ids: Vec<UserId>,
 72}
 73
 74pub struct LeftProject {
 75    pub connection_ids: Vec<ConnectionId>,
 76    pub authorized_user_ids: Vec<UserId>,
 77}
 78
 79impl Store {
 80    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 81        self.connections.insert(
 82            connection_id,
 83            ConnectionState {
 84                user_id,
 85                projects: Default::default(),
 86                channels: Default::default(),
 87            },
 88        );
 89        self.connections_by_user_id
 90            .entry(user_id)
 91            .or_default()
 92            .insert(connection_id);
 93    }
 94
 95    pub fn remove_connection(
 96        &mut self,
 97        connection_id: ConnectionId,
 98    ) -> tide::Result<RemovedConnectionState> {
 99        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
100            connection
101        } else {
102            return Err(anyhow!("no such connection"))?;
103        };
104
105        for channel_id in &connection.channels {
106            if let Some(channel) = self.channels.get_mut(&channel_id) {
107                channel.connection_ids.remove(&connection_id);
108            }
109        }
110
111        let user_connections = self
112            .connections_by_user_id
113            .get_mut(&connection.user_id)
114            .unwrap();
115        user_connections.remove(&connection_id);
116        if user_connections.is_empty() {
117            self.connections_by_user_id.remove(&connection.user_id);
118        }
119
120        let mut result = RemovedConnectionState::default();
121        for project_id in connection.projects.clone() {
122            if let Ok(project) = self.unregister_project(project_id, connection_id) {
123                result.contact_ids.extend(project.authorized_user_ids());
124                result.hosted_projects.insert(project_id, project);
125            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
126                result
127                    .guest_project_ids
128                    .insert(project_id, project.connection_ids);
129                result.contact_ids.extend(project.authorized_user_ids);
130            }
131        }
132
133        #[cfg(test)]
134        self.check_invariants();
135
136        Ok(result)
137    }
138
139    #[cfg(test)]
140    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
141        self.channels.get(&id)
142    }
143
144    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145        if let Some(connection) = self.connections.get_mut(&connection_id) {
146            connection.channels.insert(channel_id);
147            self.channels
148                .entry(channel_id)
149                .or_default()
150                .connection_ids
151                .insert(connection_id);
152        }
153    }
154
155    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
156        if let Some(connection) = self.connections.get_mut(&connection_id) {
157            connection.channels.remove(&channel_id);
158            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
159                entry.get_mut().connection_ids.remove(&connection_id);
160                if entry.get_mut().connection_ids.is_empty() {
161                    entry.remove();
162                }
163            }
164        }
165    }
166
167    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
168        Ok(self
169            .connections
170            .get(&connection_id)
171            .ok_or_else(|| anyhow!("unknown connection"))?
172            .user_id)
173    }
174
175    pub fn connection_ids_for_user<'a>(
176        &'a self,
177        user_id: UserId,
178    ) -> impl 'a + Iterator<Item = ConnectionId> {
179        self.connections_by_user_id
180            .get(&user_id)
181            .into_iter()
182            .flatten()
183            .copied()
184    }
185
186    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
187        let mut contacts = HashMap::default();
188        for project_id in self
189            .visible_projects_by_user_id
190            .get(&user_id)
191            .unwrap_or(&HashSet::default())
192        {
193            let project = &self.projects[project_id];
194
195            let mut guests = HashSet::default();
196            if let Ok(share) = project.share() {
197                for guest_connection_id in share.guests.keys() {
198                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
199                        guests.insert(user_id.to_proto());
200                    }
201                }
202            }
203
204            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
205                let mut worktree_root_names = project
206                    .worktrees
207                    .values()
208                    .filter(|worktree| worktree.visible)
209                    .map(|worktree| worktree.root_name.clone())
210                    .collect::<Vec<_>>();
211                worktree_root_names.sort_unstable();
212                contacts
213                    .entry(host_user_id)
214                    .or_insert_with(|| proto::Contact {
215                        user_id: host_user_id.to_proto(),
216                        projects: Vec::new(),
217                    })
218                    .projects
219                    .push(proto::ProjectMetadata {
220                        id: *project_id,
221                        worktree_root_names,
222                        is_shared: project.share.is_some(),
223                        guests: guests.into_iter().collect(),
224                    });
225            }
226        }
227
228        contacts.into_values().collect()
229    }
230
231    pub fn register_project(
232        &mut self,
233        host_connection_id: ConnectionId,
234        host_user_id: UserId,
235    ) -> u64 {
236        let project_id = self.next_project_id;
237        self.projects.insert(
238            project_id,
239            Project {
240                host_connection_id,
241                host_user_id,
242                share: None,
243                worktrees: Default::default(),
244                language_servers: Default::default(),
245            },
246        );
247        self.next_project_id += 1;
248        project_id
249    }
250
251    pub fn register_worktree(
252        &mut self,
253        project_id: u64,
254        worktree_id: u64,
255        connection_id: ConnectionId,
256        worktree: Worktree,
257    ) -> tide::Result<()> {
258        let project = self
259            .projects
260            .get_mut(&project_id)
261            .ok_or_else(|| anyhow!("no such project"))?;
262        if project.host_connection_id == connection_id {
263            for authorized_user_id in &worktree.authorized_user_ids {
264                self.visible_projects_by_user_id
265                    .entry(*authorized_user_id)
266                    .or_default()
267                    .insert(project_id);
268            }
269            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
270                connection.projects.insert(project_id);
271            }
272            project.worktrees.insert(worktree_id, worktree);
273            if let Ok(share) = project.share_mut() {
274                share.worktrees.insert(worktree_id, Default::default());
275            }
276
277            #[cfg(test)]
278            self.check_invariants();
279            Ok(())
280        } else {
281            Err(anyhow!("no such project"))?
282        }
283    }
284
285    pub fn unregister_project(
286        &mut self,
287        project_id: u64,
288        connection_id: ConnectionId,
289    ) -> tide::Result<Project> {
290        match self.projects.entry(project_id) {
291            hash_map::Entry::Occupied(e) => {
292                if e.get().host_connection_id == connection_id {
293                    for user_id in e.get().authorized_user_ids() {
294                        if let hash_map::Entry::Occupied(mut projects) =
295                            self.visible_projects_by_user_id.entry(user_id)
296                        {
297                            projects.get_mut().remove(&project_id);
298                        }
299                    }
300
301                    let project = e.remove();
302
303                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
304                        host_connection.projects.remove(&project_id);
305                    }
306
307                    if let Some(share) = &project.share {
308                        for guest_connection in share.guests.keys() {
309                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
310                                connection.projects.remove(&project_id);
311                            }
312                        }
313                    }
314
315                    #[cfg(test)]
316                    self.check_invariants();
317                    Ok(project)
318                } else {
319                    Err(anyhow!("no such project"))?
320                }
321            }
322            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
323        }
324    }
325
326    pub fn unregister_worktree(
327        &mut self,
328        project_id: u64,
329        worktree_id: u64,
330        acting_connection_id: ConnectionId,
331    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
332        let project = self
333            .projects
334            .get_mut(&project_id)
335            .ok_or_else(|| anyhow!("no such project"))?;
336        if project.host_connection_id != acting_connection_id {
337            Err(anyhow!("not your worktree"))?;
338        }
339
340        let worktree = project
341            .worktrees
342            .remove(&worktree_id)
343            .ok_or_else(|| anyhow!("no such worktree"))?;
344
345        let mut guest_connection_ids = Vec::new();
346        if let Ok(share) = project.share_mut() {
347            guest_connection_ids.extend(share.guests.keys());
348            share.worktrees.remove(&worktree_id);
349        }
350
351        for authorized_user_id in &worktree.authorized_user_ids {
352            if let Some(visible_projects) =
353                self.visible_projects_by_user_id.get_mut(authorized_user_id)
354            {
355                if !project.has_authorized_user_id(*authorized_user_id) {
356                    visible_projects.remove(&project_id);
357                }
358            }
359        }
360
361        #[cfg(test)]
362        self.check_invariants();
363
364        Ok((worktree, guest_connection_ids))
365    }
366
367    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
368        if let Some(project) = self.projects.get_mut(&project_id) {
369            if project.host_connection_id == connection_id {
370                let mut share = ProjectShare::default();
371                for worktree_id in project.worktrees.keys() {
372                    share.worktrees.insert(*worktree_id, Default::default());
373                }
374                project.share = Some(share);
375                return true;
376            }
377        }
378        false
379    }
380
381    pub fn unshare_project(
382        &mut self,
383        project_id: u64,
384        acting_connection_id: ConnectionId,
385    ) -> tide::Result<UnsharedProject> {
386        let project = if let Some(project) = self.projects.get_mut(&project_id) {
387            project
388        } else {
389            return Err(anyhow!("no such project"))?;
390        };
391
392        if project.host_connection_id != acting_connection_id {
393            return Err(anyhow!("not your project"))?;
394        }
395
396        let connection_ids = project.connection_ids();
397        let authorized_user_ids = project.authorized_user_ids();
398        if let Some(share) = project.share.take() {
399            for connection_id in share.guests.into_keys() {
400                if let Some(connection) = self.connections.get_mut(&connection_id) {
401                    connection.projects.remove(&project_id);
402                }
403            }
404
405            #[cfg(test)]
406            self.check_invariants();
407
408            Ok(UnsharedProject {
409                connection_ids,
410                authorized_user_ids,
411            })
412        } else {
413            Err(anyhow!("project is not shared"))?
414        }
415    }
416
417    pub fn update_diagnostic_summary(
418        &mut self,
419        project_id: u64,
420        worktree_id: u64,
421        connection_id: ConnectionId,
422        summary: proto::DiagnosticSummary,
423    ) -> tide::Result<Vec<ConnectionId>> {
424        let project = self
425            .projects
426            .get_mut(&project_id)
427            .ok_or_else(|| anyhow!("no such project"))?;
428        if project.host_connection_id == connection_id {
429            let worktree = project
430                .share_mut()?
431                .worktrees
432                .get_mut(&worktree_id)
433                .ok_or_else(|| anyhow!("no such worktree"))?;
434            worktree
435                .diagnostic_summaries
436                .insert(summary.path.clone().into(), summary);
437            return Ok(project.connection_ids());
438        }
439
440        Err(anyhow!("no such worktree"))?
441    }
442
443    pub fn start_language_server(
444        &mut self,
445        project_id: u64,
446        connection_id: ConnectionId,
447        language_server: proto::LanguageServer,
448    ) -> tide::Result<Vec<ConnectionId>> {
449        let project = self
450            .projects
451            .get_mut(&project_id)
452            .ok_or_else(|| anyhow!("no such project"))?;
453        if project.host_connection_id == connection_id {
454            project.language_servers.push(language_server);
455            return Ok(project.connection_ids());
456        }
457
458        Err(anyhow!("no such project"))?
459    }
460
461    pub fn join_project(
462        &mut self,
463        connection_id: ConnectionId,
464        user_id: UserId,
465        project_id: u64,
466    ) -> tide::Result<JoinedProject> {
467        let connection = self
468            .connections
469            .get_mut(&connection_id)
470            .ok_or_else(|| anyhow!("no such connection"))?;
471        let project = self
472            .projects
473            .get_mut(&project_id)
474            .and_then(|project| {
475                if project.has_authorized_user_id(user_id) {
476                    Some(project)
477                } else {
478                    None
479                }
480            })
481            .ok_or_else(|| anyhow!("no such project"))?;
482
483        let share = project.share_mut()?;
484        connection.projects.insert(project_id);
485
486        let mut replica_id = 1;
487        while share.active_replica_ids.contains(&replica_id) {
488            replica_id += 1;
489        }
490        share.active_replica_ids.insert(replica_id);
491        share.guests.insert(connection_id, (replica_id, user_id));
492
493        #[cfg(test)]
494        self.check_invariants();
495
496        Ok(JoinedProject {
497            replica_id,
498            project: &self.projects[&project_id],
499        })
500    }
501
502    pub fn leave_project(
503        &mut self,
504        connection_id: ConnectionId,
505        project_id: u64,
506    ) -> tide::Result<LeftProject> {
507        let project = self
508            .projects
509            .get_mut(&project_id)
510            .ok_or_else(|| anyhow!("no such project"))?;
511        let share = project
512            .share
513            .as_mut()
514            .ok_or_else(|| anyhow!("project is not shared"))?;
515        let (replica_id, _) = share
516            .guests
517            .remove(&connection_id)
518            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
519        share.active_replica_ids.remove(&replica_id);
520
521        if let Some(connection) = self.connections.get_mut(&connection_id) {
522            connection.projects.remove(&project_id);
523        }
524
525        let connection_ids = project.connection_ids();
526        let authorized_user_ids = project.authorized_user_ids();
527
528        #[cfg(test)]
529        self.check_invariants();
530
531        Ok(LeftProject {
532            connection_ids,
533            authorized_user_ids,
534        })
535    }
536
537    pub fn update_worktree(
538        &mut self,
539        connection_id: ConnectionId,
540        project_id: u64,
541        worktree_id: u64,
542        removed_entries: &[u64],
543        updated_entries: &[proto::Entry],
544    ) -> tide::Result<Vec<ConnectionId>> {
545        let project = self.write_project(project_id, connection_id)?;
546        let worktree = project
547            .share_mut()?
548            .worktrees
549            .get_mut(&worktree_id)
550            .ok_or_else(|| anyhow!("no such worktree"))?;
551        for entry_id in removed_entries {
552            worktree.entries.remove(&entry_id);
553        }
554        for entry in updated_entries {
555            worktree.entries.insert(entry.id, entry.clone());
556        }
557        let connection_ids = project.connection_ids();
558
559        #[cfg(test)]
560        self.check_invariants();
561
562        Ok(connection_ids)
563    }
564
565    pub fn project_connection_ids(
566        &self,
567        project_id: u64,
568        acting_connection_id: ConnectionId,
569    ) -> tide::Result<Vec<ConnectionId>> {
570        Ok(self
571            .read_project(project_id, acting_connection_id)?
572            .connection_ids())
573    }
574
575    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
576        Ok(self
577            .channels
578            .get(&channel_id)
579            .ok_or_else(|| anyhow!("no such channel"))?
580            .connection_ids())
581    }
582
583    #[cfg(test)]
584    pub fn project(&self, project_id: u64) -> Option<&Project> {
585        self.projects.get(&project_id)
586    }
587
588    pub fn read_project(
589        &self,
590        project_id: u64,
591        connection_id: ConnectionId,
592    ) -> tide::Result<&Project> {
593        let project = self
594            .projects
595            .get(&project_id)
596            .ok_or_else(|| anyhow!("no such project"))?;
597        if project.host_connection_id == connection_id
598            || project
599                .share
600                .as_ref()
601                .ok_or_else(|| anyhow!("project is not shared"))?
602                .guests
603                .contains_key(&connection_id)
604        {
605            Ok(project)
606        } else {
607            Err(anyhow!("no such project"))?
608        }
609    }
610
611    fn write_project(
612        &mut self,
613        project_id: u64,
614        connection_id: ConnectionId,
615    ) -> tide::Result<&mut Project> {
616        let project = self
617            .projects
618            .get_mut(&project_id)
619            .ok_or_else(|| anyhow!("no such project"))?;
620        if project.host_connection_id == connection_id
621            || project
622                .share
623                .as_ref()
624                .ok_or_else(|| anyhow!("project is not shared"))?
625                .guests
626                .contains_key(&connection_id)
627        {
628            Ok(project)
629        } else {
630            Err(anyhow!("no such project"))?
631        }
632    }
633
634    #[cfg(test)]
635    fn check_invariants(&self) {
636        for (connection_id, connection) in &self.connections {
637            for project_id in &connection.projects {
638                let project = &self.projects.get(&project_id).unwrap();
639                if project.host_connection_id != *connection_id {
640                    assert!(project
641                        .share
642                        .as_ref()
643                        .unwrap()
644                        .guests
645                        .contains_key(connection_id));
646                }
647
648                if let Some(share) = project.share.as_ref() {
649                    for (worktree_id, worktree) in share.worktrees.iter() {
650                        let mut paths = HashMap::default();
651                        for entry in worktree.entries.values() {
652                            let prev_entry = paths.insert(&entry.path, entry);
653                            assert_eq!(
654                                prev_entry,
655                                None,
656                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
657                                worktree_id,
658                                prev_entry.unwrap(),
659                                entry
660                            );
661                        }
662                    }
663                }
664            }
665            for channel_id in &connection.channels {
666                let channel = self.channels.get(channel_id).unwrap();
667                assert!(channel.connection_ids.contains(connection_id));
668            }
669            assert!(self
670                .connections_by_user_id
671                .get(&connection.user_id)
672                .unwrap()
673                .contains(connection_id));
674        }
675
676        for (user_id, connection_ids) in &self.connections_by_user_id {
677            for connection_id in connection_ids {
678                assert_eq!(
679                    self.connections.get(connection_id).unwrap().user_id,
680                    *user_id
681                );
682            }
683        }
684
685        for (project_id, project) in &self.projects {
686            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
687            assert!(host_connection.projects.contains(project_id));
688
689            for authorized_user_ids in project.authorized_user_ids() {
690                let visible_project_ids = self
691                    .visible_projects_by_user_id
692                    .get(&authorized_user_ids)
693                    .unwrap();
694                assert!(visible_project_ids.contains(project_id));
695            }
696
697            if let Some(share) = &project.share {
698                for guest_connection_id in share.guests.keys() {
699                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
700                    assert!(guest_connection.projects.contains(project_id));
701                }
702                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
703                assert_eq!(
704                    share.active_replica_ids,
705                    share
706                        .guests
707                        .values()
708                        .map(|(replica_id, _)| *replica_id)
709                        .collect::<HashSet<_>>(),
710                );
711            }
712        }
713
714        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
715            for project_id in visible_project_ids {
716                let project = self.projects.get(project_id).unwrap();
717                assert!(project.authorized_user_ids().contains(user_id));
718            }
719        }
720
721        for (channel_id, channel) in &self.channels {
722            for connection_id in &channel.connection_ids {
723                let connection = self.connections.get(connection_id).unwrap();
724                assert!(connection.channels.contains(channel_id));
725            }
726        }
727    }
728}
729
730impl Project {
731    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
732        self.worktrees
733            .values()
734            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
735    }
736
737    pub fn authorized_user_ids(&self) -> Vec<UserId> {
738        let mut ids = self
739            .worktrees
740            .values()
741            .flat_map(|worktree| worktree.authorized_user_ids.iter())
742            .copied()
743            .collect::<Vec<_>>();
744        ids.sort_unstable();
745        ids.dedup();
746        ids
747    }
748
749    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
750        if let Some(share) = &self.share {
751            share.guests.keys().copied().collect()
752        } else {
753            Vec::new()
754        }
755    }
756
757    pub fn connection_ids(&self) -> Vec<ConnectionId> {
758        if let Some(share) = &self.share {
759            share
760                .guests
761                .keys()
762                .copied()
763                .chain(Some(self.host_connection_id))
764                .collect()
765        } else {
766            vec![self.host_connection_id]
767        }
768    }
769
770    pub fn share(&self) -> tide::Result<&ProjectShare> {
771        Ok(self
772            .share
773            .as_ref()
774            .ok_or_else(|| anyhow!("worktree is not shared"))?)
775    }
776
777    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
778        Ok(self
779            .share
780            .as_mut()
781            .ok_or_else(|| anyhow!("worktree is not shared"))?)
782    }
783}
784
785impl Channel {
786    fn connection_ids(&self) -> Vec<ConnectionId> {
787        self.connection_ids.iter().copied().collect()
788    }
789}