store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub authorized_user_ids: Vec<UserId>,
 33    pub root_name: String,
 34    pub visible: bool,
 35}
 36
 37#[derive(Default)]
 38pub struct ProjectShare {
 39    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 40    pub active_replica_ids: HashSet<ReplicaId>,
 41    pub worktrees: HashMap<u64, WorktreeShare>,
 42}
 43
 44#[derive(Default)]
 45pub struct WorktreeShare {
 46    pub entries: HashMap<u64, proto::Entry>,
 47    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub hosted_projects: HashMap<u64, Project>,
 60    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 61    pub contact_ids: HashSet<UserId>,
 62}
 63
 64pub struct JoinedProject<'a> {
 65    pub replica_id: ReplicaId,
 66    pub project: &'a Project,
 67}
 68
 69pub struct UnsharedProject {
 70    pub connection_ids: Vec<ConnectionId>,
 71    pub authorized_user_ids: Vec<UserId>,
 72}
 73
 74pub struct LeftProject {
 75    pub connection_ids: Vec<ConnectionId>,
 76    pub authorized_user_ids: Vec<UserId>,
 77}
 78
 79impl Store {
 80    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 81        self.connections.insert(
 82            connection_id,
 83            ConnectionState {
 84                user_id,
 85                projects: Default::default(),
 86                channels: Default::default(),
 87            },
 88        );
 89        self.connections_by_user_id
 90            .entry(user_id)
 91            .or_default()
 92            .insert(connection_id);
 93    }
 94
 95    pub fn remove_connection(
 96        &mut self,
 97        connection_id: ConnectionId,
 98    ) -> tide::Result<RemovedConnectionState> {
 99        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
100            connection
101        } else {
102            return Err(anyhow!("no such connection"))?;
103        };
104
105        for channel_id in &connection.channels {
106            if let Some(channel) = self.channels.get_mut(&channel_id) {
107                channel.connection_ids.remove(&connection_id);
108            }
109        }
110
111        let user_connections = self
112            .connections_by_user_id
113            .get_mut(&connection.user_id)
114            .unwrap();
115        user_connections.remove(&connection_id);
116        if user_connections.is_empty() {
117            self.connections_by_user_id.remove(&connection.user_id);
118        }
119
120        let mut result = RemovedConnectionState::default();
121        for project_id in connection.projects.clone() {
122            if let Ok(project) = self.unregister_project(project_id, connection_id) {
123                result.contact_ids.extend(project.authorized_user_ids());
124                result.hosted_projects.insert(project_id, project);
125            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
126                result
127                    .guest_project_ids
128                    .insert(project_id, project.connection_ids);
129                result.contact_ids.extend(project.authorized_user_ids);
130            }
131        }
132
133        Ok(result)
134    }
135
136    #[cfg(test)]
137    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
138        self.channels.get(&id)
139    }
140
141    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
142        if let Some(connection) = self.connections.get_mut(&connection_id) {
143            connection.channels.insert(channel_id);
144            self.channels
145                .entry(channel_id)
146                .or_default()
147                .connection_ids
148                .insert(connection_id);
149        }
150    }
151
152    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
153        if let Some(connection) = self.connections.get_mut(&connection_id) {
154            connection.channels.remove(&channel_id);
155            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
156                entry.get_mut().connection_ids.remove(&connection_id);
157                if entry.get_mut().connection_ids.is_empty() {
158                    entry.remove();
159                }
160            }
161        }
162    }
163
164    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
165        Ok(self
166            .connections
167            .get(&connection_id)
168            .ok_or_else(|| anyhow!("unknown connection"))?
169            .user_id)
170    }
171
172    pub fn connection_ids_for_user<'a>(
173        &'a self,
174        user_id: UserId,
175    ) -> impl 'a + Iterator<Item = ConnectionId> {
176        self.connections_by_user_id
177            .get(&user_id)
178            .into_iter()
179            .flatten()
180            .copied()
181    }
182
183    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
184        let mut contacts = HashMap::default();
185        for project_id in self
186            .visible_projects_by_user_id
187            .get(&user_id)
188            .unwrap_or(&HashSet::default())
189        {
190            let project = &self.projects[project_id];
191
192            let mut guests = HashSet::default();
193            if let Ok(share) = project.share() {
194                for guest_connection_id in share.guests.keys() {
195                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
196                        guests.insert(user_id.to_proto());
197                    }
198                }
199            }
200
201            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
202                let mut worktree_root_names = project
203                    .worktrees
204                    .values()
205                    .filter(|worktree| worktree.visible)
206                    .map(|worktree| worktree.root_name.clone())
207                    .collect::<Vec<_>>();
208                worktree_root_names.sort_unstable();
209                contacts
210                    .entry(host_user_id)
211                    .or_insert_with(|| proto::Contact {
212                        user_id: host_user_id.to_proto(),
213                        projects: Vec::new(),
214                    })
215                    .projects
216                    .push(proto::ProjectMetadata {
217                        id: *project_id,
218                        worktree_root_names,
219                        is_shared: project.share.is_some(),
220                        guests: guests.into_iter().collect(),
221                    });
222            }
223        }
224
225        contacts.into_values().collect()
226    }
227
228    pub fn register_project(
229        &mut self,
230        host_connection_id: ConnectionId,
231        host_user_id: UserId,
232    ) -> u64 {
233        let project_id = self.next_project_id;
234        self.projects.insert(
235            project_id,
236            Project {
237                host_connection_id,
238                host_user_id,
239                share: None,
240                worktrees: Default::default(),
241                language_servers: Default::default(),
242            },
243        );
244        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
245            connection.projects.insert(project_id);
246        }
247        self.next_project_id += 1;
248        project_id
249    }
250
251    pub fn register_worktree(
252        &mut self,
253        project_id: u64,
254        worktree_id: u64,
255        connection_id: ConnectionId,
256        worktree: Worktree,
257    ) -> tide::Result<()> {
258        let project = self
259            .projects
260            .get_mut(&project_id)
261            .ok_or_else(|| anyhow!("no such project"))?;
262        if project.host_connection_id == connection_id {
263            for authorized_user_id in &worktree.authorized_user_ids {
264                self.visible_projects_by_user_id
265                    .entry(*authorized_user_id)
266                    .or_default()
267                    .insert(project_id);
268            }
269
270            project.worktrees.insert(worktree_id, worktree);
271            if let Ok(share) = project.share_mut() {
272                share.worktrees.insert(worktree_id, Default::default());
273            }
274
275            Ok(())
276        } else {
277            Err(anyhow!("no such project"))?
278        }
279    }
280
281    pub fn unregister_project(
282        &mut self,
283        project_id: u64,
284        connection_id: ConnectionId,
285    ) -> tide::Result<Project> {
286        match self.projects.entry(project_id) {
287            hash_map::Entry::Occupied(e) => {
288                if e.get().host_connection_id == connection_id {
289                    for user_id in e.get().authorized_user_ids() {
290                        if let hash_map::Entry::Occupied(mut projects) =
291                            self.visible_projects_by_user_id.entry(user_id)
292                        {
293                            projects.get_mut().remove(&project_id);
294                        }
295                    }
296
297                    let project = e.remove();
298
299                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
300                        host_connection.projects.remove(&project_id);
301                    }
302
303                    if let Some(share) = &project.share {
304                        for guest_connection in share.guests.keys() {
305                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
306                                connection.projects.remove(&project_id);
307                            }
308                        }
309                    }
310
311                    Ok(project)
312                } else {
313                    Err(anyhow!("no such project"))?
314                }
315            }
316            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
317        }
318    }
319
320    pub fn unregister_worktree(
321        &mut self,
322        project_id: u64,
323        worktree_id: u64,
324        acting_connection_id: ConnectionId,
325    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
326        let project = self
327            .projects
328            .get_mut(&project_id)
329            .ok_or_else(|| anyhow!("no such project"))?;
330        if project.host_connection_id != acting_connection_id {
331            Err(anyhow!("not your worktree"))?;
332        }
333
334        let worktree = project
335            .worktrees
336            .remove(&worktree_id)
337            .ok_or_else(|| anyhow!("no such worktree"))?;
338
339        let mut guest_connection_ids = Vec::new();
340        if let Ok(share) = project.share_mut() {
341            guest_connection_ids.extend(share.guests.keys());
342            share.worktrees.remove(&worktree_id);
343        }
344
345        for authorized_user_id in &worktree.authorized_user_ids {
346            if let Some(visible_projects) =
347                self.visible_projects_by_user_id.get_mut(authorized_user_id)
348            {
349                if !project.has_authorized_user_id(*authorized_user_id) {
350                    visible_projects.remove(&project_id);
351                }
352            }
353        }
354
355        Ok((worktree, guest_connection_ids))
356    }
357
358    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
359        if let Some(project) = self.projects.get_mut(&project_id) {
360            if project.host_connection_id == connection_id {
361                let mut share = ProjectShare::default();
362                for worktree_id in project.worktrees.keys() {
363                    share.worktrees.insert(*worktree_id, Default::default());
364                }
365                project.share = Some(share);
366                return true;
367            }
368        }
369        false
370    }
371
372    pub fn unshare_project(
373        &mut self,
374        project_id: u64,
375        acting_connection_id: ConnectionId,
376    ) -> tide::Result<UnsharedProject> {
377        let project = if let Some(project) = self.projects.get_mut(&project_id) {
378            project
379        } else {
380            return Err(anyhow!("no such project"))?;
381        };
382
383        if project.host_connection_id != acting_connection_id {
384            return Err(anyhow!("not your project"))?;
385        }
386
387        let connection_ids = project.connection_ids();
388        let authorized_user_ids = project.authorized_user_ids();
389        if let Some(share) = project.share.take() {
390            for connection_id in share.guests.into_keys() {
391                if let Some(connection) = self.connections.get_mut(&connection_id) {
392                    connection.projects.remove(&project_id);
393                }
394            }
395
396            Ok(UnsharedProject {
397                connection_ids,
398                authorized_user_ids,
399            })
400        } else {
401            Err(anyhow!("project is not shared"))?
402        }
403    }
404
405    pub fn update_diagnostic_summary(
406        &mut self,
407        project_id: u64,
408        worktree_id: u64,
409        connection_id: ConnectionId,
410        summary: proto::DiagnosticSummary,
411    ) -> tide::Result<Vec<ConnectionId>> {
412        let project = self
413            .projects
414            .get_mut(&project_id)
415            .ok_or_else(|| anyhow!("no such project"))?;
416        if project.host_connection_id == connection_id {
417            let worktree = project
418                .share_mut()?
419                .worktrees
420                .get_mut(&worktree_id)
421                .ok_or_else(|| anyhow!("no such worktree"))?;
422            worktree
423                .diagnostic_summaries
424                .insert(summary.path.clone().into(), summary);
425            return Ok(project.connection_ids());
426        }
427
428        Err(anyhow!("no such worktree"))?
429    }
430
431    pub fn start_language_server(
432        &mut self,
433        project_id: u64,
434        connection_id: ConnectionId,
435        language_server: proto::LanguageServer,
436    ) -> tide::Result<Vec<ConnectionId>> {
437        let project = self
438            .projects
439            .get_mut(&project_id)
440            .ok_or_else(|| anyhow!("no such project"))?;
441        if project.host_connection_id == connection_id {
442            project.language_servers.push(language_server);
443            return Ok(project.connection_ids());
444        }
445
446        Err(anyhow!("no such project"))?
447    }
448
449    pub fn join_project(
450        &mut self,
451        connection_id: ConnectionId,
452        user_id: UserId,
453        project_id: u64,
454    ) -> tide::Result<JoinedProject> {
455        let connection = self
456            .connections
457            .get_mut(&connection_id)
458            .ok_or_else(|| anyhow!("no such connection"))?;
459        let project = self
460            .projects
461            .get_mut(&project_id)
462            .and_then(|project| {
463                if project.has_authorized_user_id(user_id) {
464                    Some(project)
465                } else {
466                    None
467                }
468            })
469            .ok_or_else(|| anyhow!("no such project"))?;
470
471        let share = project.share_mut()?;
472        connection.projects.insert(project_id);
473
474        let mut replica_id = 1;
475        while share.active_replica_ids.contains(&replica_id) {
476            replica_id += 1;
477        }
478        share.active_replica_ids.insert(replica_id);
479        share.guests.insert(connection_id, (replica_id, user_id));
480
481        Ok(JoinedProject {
482            replica_id,
483            project: &self.projects[&project_id],
484        })
485    }
486
487    pub fn leave_project(
488        &mut self,
489        connection_id: ConnectionId,
490        project_id: u64,
491    ) -> tide::Result<LeftProject> {
492        let project = self
493            .projects
494            .get_mut(&project_id)
495            .ok_or_else(|| anyhow!("no such project"))?;
496        let share = project
497            .share
498            .as_mut()
499            .ok_or_else(|| anyhow!("project is not shared"))?;
500        let (replica_id, _) = share
501            .guests
502            .remove(&connection_id)
503            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
504        share.active_replica_ids.remove(&replica_id);
505
506        if let Some(connection) = self.connections.get_mut(&connection_id) {
507            connection.projects.remove(&project_id);
508        }
509
510        let connection_ids = project.connection_ids();
511        let authorized_user_ids = project.authorized_user_ids();
512
513        Ok(LeftProject {
514            connection_ids,
515            authorized_user_ids,
516        })
517    }
518
519    pub fn update_worktree(
520        &mut self,
521        connection_id: ConnectionId,
522        project_id: u64,
523        worktree_id: u64,
524        removed_entries: &[u64],
525        updated_entries: &[proto::Entry],
526    ) -> tide::Result<Vec<ConnectionId>> {
527        let project = self.write_project(project_id, connection_id)?;
528        let worktree = project
529            .share_mut()?
530            .worktrees
531            .get_mut(&worktree_id)
532            .ok_or_else(|| anyhow!("no such worktree"))?;
533        for entry_id in removed_entries {
534            worktree.entries.remove(&entry_id);
535        }
536        for entry in updated_entries {
537            worktree.entries.insert(entry.id, entry.clone());
538        }
539        let connection_ids = project.connection_ids();
540        Ok(connection_ids)
541    }
542
543    pub fn project_connection_ids(
544        &self,
545        project_id: u64,
546        acting_connection_id: ConnectionId,
547    ) -> tide::Result<Vec<ConnectionId>> {
548        Ok(self
549            .read_project(project_id, acting_connection_id)?
550            .connection_ids())
551    }
552
553    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
554        Ok(self
555            .channels
556            .get(&channel_id)
557            .ok_or_else(|| anyhow!("no such channel"))?
558            .connection_ids())
559    }
560
561    #[cfg(test)]
562    pub fn project(&self, project_id: u64) -> Option<&Project> {
563        self.projects.get(&project_id)
564    }
565
566    pub fn read_project(
567        &self,
568        project_id: u64,
569        connection_id: ConnectionId,
570    ) -> tide::Result<&Project> {
571        let project = self
572            .projects
573            .get(&project_id)
574            .ok_or_else(|| anyhow!("no such project"))?;
575        if project.host_connection_id == connection_id
576            || project
577                .share
578                .as_ref()
579                .ok_or_else(|| anyhow!("project is not shared"))?
580                .guests
581                .contains_key(&connection_id)
582        {
583            Ok(project)
584        } else {
585            Err(anyhow!("no such project"))?
586        }
587    }
588
589    fn write_project(
590        &mut self,
591        project_id: u64,
592        connection_id: ConnectionId,
593    ) -> tide::Result<&mut Project> {
594        let project = self
595            .projects
596            .get_mut(&project_id)
597            .ok_or_else(|| anyhow!("no such project"))?;
598        if project.host_connection_id == connection_id
599            || project
600                .share
601                .as_ref()
602                .ok_or_else(|| anyhow!("project is not shared"))?
603                .guests
604                .contains_key(&connection_id)
605        {
606            Ok(project)
607        } else {
608            Err(anyhow!("no such project"))?
609        }
610    }
611
612    #[cfg(test)]
613    pub fn check_invariants(&self) {
614        for (connection_id, connection) in &self.connections {
615            for project_id in &connection.projects {
616                let project = &self.projects.get(&project_id).unwrap();
617                if project.host_connection_id != *connection_id {
618                    assert!(project
619                        .share
620                        .as_ref()
621                        .unwrap()
622                        .guests
623                        .contains_key(connection_id));
624                }
625
626                if let Some(share) = project.share.as_ref() {
627                    for (worktree_id, worktree) in share.worktrees.iter() {
628                        let mut paths = HashMap::default();
629                        for entry in worktree.entries.values() {
630                            let prev_entry = paths.insert(&entry.path, entry);
631                            assert_eq!(
632                                prev_entry,
633                                None,
634                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
635                                worktree_id,
636                                prev_entry.unwrap(),
637                                entry
638                            );
639                        }
640                    }
641                }
642            }
643            for channel_id in &connection.channels {
644                let channel = self.channels.get(channel_id).unwrap();
645                assert!(channel.connection_ids.contains(connection_id));
646            }
647            assert!(self
648                .connections_by_user_id
649                .get(&connection.user_id)
650                .unwrap()
651                .contains(connection_id));
652        }
653
654        for (user_id, connection_ids) in &self.connections_by_user_id {
655            for connection_id in connection_ids {
656                assert_eq!(
657                    self.connections.get(connection_id).unwrap().user_id,
658                    *user_id
659                );
660            }
661        }
662
663        for (project_id, project) in &self.projects {
664            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
665            assert!(host_connection.projects.contains(project_id));
666
667            for authorized_user_ids in project.authorized_user_ids() {
668                let visible_project_ids = self
669                    .visible_projects_by_user_id
670                    .get(&authorized_user_ids)
671                    .unwrap();
672                assert!(visible_project_ids.contains(project_id));
673            }
674
675            if let Some(share) = &project.share {
676                for guest_connection_id in share.guests.keys() {
677                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
678                    assert!(guest_connection.projects.contains(project_id));
679                }
680                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
681                assert_eq!(
682                    share.active_replica_ids,
683                    share
684                        .guests
685                        .values()
686                        .map(|(replica_id, _)| *replica_id)
687                        .collect::<HashSet<_>>(),
688                );
689            }
690        }
691
692        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
693            for project_id in visible_project_ids {
694                let project = self.projects.get(project_id).unwrap();
695                assert!(project.authorized_user_ids().contains(user_id));
696            }
697        }
698
699        for (channel_id, channel) in &self.channels {
700            for connection_id in &channel.connection_ids {
701                let connection = self.connections.get(connection_id).unwrap();
702                assert!(connection.channels.contains(channel_id));
703            }
704        }
705    }
706}
707
708impl Project {
709    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
710        self.worktrees
711            .values()
712            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
713    }
714
715    pub fn authorized_user_ids(&self) -> Vec<UserId> {
716        let mut ids = self
717            .worktrees
718            .values()
719            .flat_map(|worktree| worktree.authorized_user_ids.iter())
720            .copied()
721            .collect::<Vec<_>>();
722        ids.sort_unstable();
723        ids.dedup();
724        ids
725    }
726
727    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
728        if let Some(share) = &self.share {
729            share.guests.keys().copied().collect()
730        } else {
731            Vec::new()
732        }
733    }
734
735    pub fn connection_ids(&self) -> Vec<ConnectionId> {
736        if let Some(share) = &self.share {
737            share
738                .guests
739                .keys()
740                .copied()
741                .chain(Some(self.host_connection_id))
742                .collect()
743        } else {
744            vec![self.host_connection_id]
745        }
746    }
747
748    pub fn share(&self) -> tide::Result<&ProjectShare> {
749        Ok(self
750            .share
751            .as_ref()
752            .ok_or_else(|| anyhow!("worktree is not shared"))?)
753    }
754
755    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
756        Ok(self
757            .share
758            .as_mut()
759            .ok_or_else(|| anyhow!("worktree is not shared"))?)
760    }
761}
762
763impl Channel {
764    fn connection_ids(&self) -> Vec<ConnectionId> {
765        self.connection_ids.iter().copied().collect()
766    }
767}