store.rs

  1use crate::db::{self, ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub root_name: String,
 33    pub visible: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47    pub scan_id: u64,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub user_id: UserId,
 60    pub hosted_projects: HashMap<u64, Project>,
 61    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 62    pub contact_ids: HashSet<UserId>,
 63}
 64
 65pub struct JoinedProject<'a> {
 66    pub replica_id: ReplicaId,
 67    pub project: &'a Project,
 68}
 69
 70pub struct SharedProject {}
 71
 72pub struct UnsharedProject {
 73    pub connection_ids: Vec<ConnectionId>,
 74    pub host_user_id: UserId,
 75}
 76
 77pub struct LeftProject {
 78    pub connection_ids: Vec<ConnectionId>,
 79    pub host_user_id: UserId,
 80}
 81
 82#[derive(Copy, Clone)]
 83pub struct Metrics {
 84    pub connections: usize,
 85    pub registered_projects: usize,
 86    pub shared_projects: usize,
 87}
 88
 89impl Store {
 90    pub fn metrics(&self) -> Metrics {
 91        let connections = self.connections.len();
 92        let mut registered_projects = 0;
 93        let mut shared_projects = 0;
 94        for project in self.projects.values() {
 95            registered_projects += 1;
 96            if project.share.is_some() {
 97                shared_projects += 1;
 98            }
 99        }
100
101        Metrics {
102            connections,
103            registered_projects,
104            shared_projects,
105        }
106    }
107
108    #[instrument(skip(self))]
109    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
110        self.connections.insert(
111            connection_id,
112            ConnectionState {
113                user_id,
114                projects: Default::default(),
115                channels: Default::default(),
116            },
117        );
118        self.connections_by_user_id
119            .entry(user_id)
120            .or_default()
121            .insert(connection_id);
122    }
123
124    #[instrument(skip(self))]
125    pub fn remove_connection(
126        &mut self,
127        connection_id: ConnectionId,
128    ) -> Result<RemovedConnectionState> {
129        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
130            connection
131        } else {
132            return Err(anyhow!("no such connection"))?;
133        };
134
135        for channel_id in &connection.channels {
136            if let Some(channel) = self.channels.get_mut(&channel_id) {
137                channel.connection_ids.remove(&connection_id);
138            }
139        }
140
141        let user_connections = self
142            .connections_by_user_id
143            .get_mut(&connection.user_id)
144            .unwrap();
145        user_connections.remove(&connection_id);
146        if user_connections.is_empty() {
147            self.connections_by_user_id.remove(&connection.user_id);
148        }
149
150        let mut result = RemovedConnectionState::default();
151        result.user_id = connection.user_id;
152        for project_id in connection.projects.clone() {
153            if let Ok(project) = self.unregister_project(project_id, connection_id) {
154                result.hosted_projects.insert(project_id, project);
155            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
156                result
157                    .guest_project_ids
158                    .insert(project_id, project.connection_ids);
159            }
160        }
161
162        Ok(result)
163    }
164
165    #[cfg(test)]
166    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
167        self.channels.get(&id)
168    }
169
170    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
171        if let Some(connection) = self.connections.get_mut(&connection_id) {
172            connection.channels.insert(channel_id);
173            self.channels
174                .entry(channel_id)
175                .or_default()
176                .connection_ids
177                .insert(connection_id);
178        }
179    }
180
181    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
182        if let Some(connection) = self.connections.get_mut(&connection_id) {
183            connection.channels.remove(&channel_id);
184            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
185                entry.get_mut().connection_ids.remove(&connection_id);
186                if entry.get_mut().connection_ids.is_empty() {
187                    entry.remove();
188                }
189            }
190        }
191    }
192
193    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
194        Ok(self
195            .connections
196            .get(&connection_id)
197            .ok_or_else(|| anyhow!("unknown connection"))?
198            .user_id)
199    }
200
201    pub fn connection_ids_for_user<'a>(
202        &'a self,
203        user_id: UserId,
204    ) -> impl 'a + Iterator<Item = ConnectionId> {
205        self.connections_by_user_id
206            .get(&user_id)
207            .into_iter()
208            .flatten()
209            .copied()
210    }
211
212    pub fn is_user_online(&self, user_id: UserId) -> bool {
213        !self
214            .connections_by_user_id
215            .get(&user_id)
216            .unwrap_or(&Default::default())
217            .is_empty()
218    }
219
220    pub fn build_initial_contacts_update(
221        &self,
222        contacts: Vec<db::Contact>,
223    ) -> proto::UpdateContacts {
224        let mut update = proto::UpdateContacts::default();
225
226        for contact in contacts {
227            match contact {
228                db::Contact::Accepted {
229                    user_id,
230                    should_notify,
231                } => {
232                    update
233                        .contacts
234                        .push(self.contact_for_user(user_id, should_notify));
235                }
236                db::Contact::Outgoing { user_id } => {
237                    update.outgoing_requests.push(user_id.to_proto())
238                }
239                db::Contact::Incoming {
240                    user_id,
241                    should_notify,
242                } => update
243                    .incoming_requests
244                    .push(proto::IncomingContactRequest {
245                        requester_id: user_id.to_proto(),
246                        should_notify,
247                    }),
248            }
249        }
250
251        update
252    }
253
254    pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
255        proto::Contact {
256            user_id: user_id.to_proto(),
257            projects: self.project_metadata_for_user(user_id),
258            online: self.is_user_online(user_id),
259            should_notify,
260        }
261    }
262
263    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
264        let connection_ids = self.connections_by_user_id.get(&user_id);
265        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
266            connection_ids
267                .iter()
268                .filter_map(|connection_id| self.connections.get(connection_id))
269                .flat_map(|connection| connection.projects.iter().copied())
270        });
271
272        let mut metadata = Vec::new();
273        for project_id in project_ids {
274            if let Some(project) = self.projects.get(&project_id) {
275                if project.host_user_id == user_id {
276                    metadata.push(proto::ProjectMetadata {
277                        id: project_id,
278                        is_shared: project.share.is_some(),
279                        worktree_root_names: project
280                            .worktrees
281                            .values()
282                            .map(|worktree| worktree.root_name.clone())
283                            .collect(),
284                        guests: project
285                            .share
286                            .iter()
287                            .flat_map(|share| {
288                                share.guests.values().map(|(_, user_id)| user_id.to_proto())
289                            })
290                            .collect(),
291                    });
292                }
293            }
294        }
295
296        metadata
297    }
298
299    pub fn register_project(
300        &mut self,
301        host_connection_id: ConnectionId,
302        host_user_id: UserId,
303    ) -> u64 {
304        let project_id = self.next_project_id;
305        self.projects.insert(
306            project_id,
307            Project {
308                host_connection_id,
309                host_user_id,
310                share: None,
311                worktrees: Default::default(),
312                language_servers: Default::default(),
313            },
314        );
315        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
316            connection.projects.insert(project_id);
317        }
318        self.next_project_id += 1;
319        project_id
320    }
321
322    pub fn register_worktree(
323        &mut self,
324        project_id: u64,
325        worktree_id: u64,
326        connection_id: ConnectionId,
327        worktree: Worktree,
328    ) -> Result<()> {
329        let project = self
330            .projects
331            .get_mut(&project_id)
332            .ok_or_else(|| anyhow!("no such project"))?;
333        if project.host_connection_id == connection_id {
334            project.worktrees.insert(worktree_id, worktree);
335            if let Ok(share) = project.share_mut() {
336                share.worktrees.insert(worktree_id, Default::default());
337            }
338
339            Ok(())
340        } else {
341            Err(anyhow!("no such project"))?
342        }
343    }
344
345    pub fn unregister_project(
346        &mut self,
347        project_id: u64,
348        connection_id: ConnectionId,
349    ) -> Result<Project> {
350        match self.projects.entry(project_id) {
351            hash_map::Entry::Occupied(e) => {
352                if e.get().host_connection_id == connection_id {
353                    let project = e.remove();
354
355                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
356                        host_connection.projects.remove(&project_id);
357                    }
358
359                    if let Some(share) = &project.share {
360                        for guest_connection in share.guests.keys() {
361                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
362                                connection.projects.remove(&project_id);
363                            }
364                        }
365                    }
366
367                    Ok(project)
368                } else {
369                    Err(anyhow!("no such project"))?
370                }
371            }
372            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
373        }
374    }
375
376    pub fn unregister_worktree(
377        &mut self,
378        project_id: u64,
379        worktree_id: u64,
380        acting_connection_id: ConnectionId,
381    ) -> Result<(Worktree, Vec<ConnectionId>)> {
382        let project = self
383            .projects
384            .get_mut(&project_id)
385            .ok_or_else(|| anyhow!("no such project"))?;
386        if project.host_connection_id != acting_connection_id {
387            Err(anyhow!("not your worktree"))?;
388        }
389
390        let worktree = project
391            .worktrees
392            .remove(&worktree_id)
393            .ok_or_else(|| anyhow!("no such worktree"))?;
394
395        let mut guest_connection_ids = Vec::new();
396        if let Ok(share) = project.share_mut() {
397            guest_connection_ids.extend(share.guests.keys());
398            share.worktrees.remove(&worktree_id);
399        }
400
401        Ok((worktree, guest_connection_ids))
402    }
403
404    pub fn share_project(
405        &mut self,
406        project_id: u64,
407        connection_id: ConnectionId,
408    ) -> Result<SharedProject> {
409        if let Some(project) = self.projects.get_mut(&project_id) {
410            if project.host_connection_id == connection_id {
411                let mut share = ProjectShare::default();
412                for worktree_id in project.worktrees.keys() {
413                    share.worktrees.insert(*worktree_id, Default::default());
414                }
415                project.share = Some(share);
416                return Ok(SharedProject {});
417            }
418        }
419        Err(anyhow!("no such project"))?
420    }
421
422    pub fn unshare_project(
423        &mut self,
424        project_id: u64,
425        acting_connection_id: ConnectionId,
426    ) -> Result<UnsharedProject> {
427        let project = if let Some(project) = self.projects.get_mut(&project_id) {
428            project
429        } else {
430            return Err(anyhow!("no such project"))?;
431        };
432
433        if project.host_connection_id != acting_connection_id {
434            return Err(anyhow!("not your project"))?;
435        }
436
437        let connection_ids = project.connection_ids();
438        if let Some(share) = project.share.take() {
439            for connection_id in share.guests.into_keys() {
440                if let Some(connection) = self.connections.get_mut(&connection_id) {
441                    connection.projects.remove(&project_id);
442                }
443            }
444
445            Ok(UnsharedProject {
446                connection_ids,
447                host_user_id: project.host_user_id,
448            })
449        } else {
450            Err(anyhow!("project is not shared"))?
451        }
452    }
453
454    pub fn update_diagnostic_summary(
455        &mut self,
456        project_id: u64,
457        worktree_id: u64,
458        connection_id: ConnectionId,
459        summary: proto::DiagnosticSummary,
460    ) -> Result<Vec<ConnectionId>> {
461        let project = self
462            .projects
463            .get_mut(&project_id)
464            .ok_or_else(|| anyhow!("no such project"))?;
465        if project.host_connection_id == connection_id {
466            let worktree = project
467                .share_mut()?
468                .worktrees
469                .get_mut(&worktree_id)
470                .ok_or_else(|| anyhow!("no such worktree"))?;
471            worktree
472                .diagnostic_summaries
473                .insert(summary.path.clone().into(), summary);
474            return Ok(project.connection_ids());
475        }
476
477        Err(anyhow!("no such worktree"))?
478    }
479
480    pub fn start_language_server(
481        &mut self,
482        project_id: u64,
483        connection_id: ConnectionId,
484        language_server: proto::LanguageServer,
485    ) -> Result<Vec<ConnectionId>> {
486        let project = self
487            .projects
488            .get_mut(&project_id)
489            .ok_or_else(|| anyhow!("no such project"))?;
490        if project.host_connection_id == connection_id {
491            project.language_servers.push(language_server);
492            return Ok(project.connection_ids());
493        }
494
495        Err(anyhow!("no such project"))?
496    }
497
498    pub fn join_project(
499        &mut self,
500        connection_id: ConnectionId,
501        user_id: UserId,
502        project_id: u64,
503    ) -> Result<JoinedProject> {
504        let connection = self
505            .connections
506            .get_mut(&connection_id)
507            .ok_or_else(|| anyhow!("no such connection"))?;
508        let project = self
509            .projects
510            .get_mut(&project_id)
511            .ok_or_else(|| anyhow!("no such project"))?;
512
513        let share = project.share_mut()?;
514        connection.projects.insert(project_id);
515
516        let mut replica_id = 1;
517        while share.active_replica_ids.contains(&replica_id) {
518            replica_id += 1;
519        }
520        share.active_replica_ids.insert(replica_id);
521        share.guests.insert(connection_id, (replica_id, user_id));
522
523        Ok(JoinedProject {
524            replica_id,
525            project: &self.projects[&project_id],
526        })
527    }
528
529    pub fn leave_project(
530        &mut self,
531        connection_id: ConnectionId,
532        project_id: u64,
533    ) -> Result<LeftProject> {
534        let project = self
535            .projects
536            .get_mut(&project_id)
537            .ok_or_else(|| anyhow!("no such project"))?;
538        let share = project
539            .share
540            .as_mut()
541            .ok_or_else(|| anyhow!("project is not shared"))?;
542        let (replica_id, _) = share
543            .guests
544            .remove(&connection_id)
545            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
546        share.active_replica_ids.remove(&replica_id);
547
548        if let Some(connection) = self.connections.get_mut(&connection_id) {
549            connection.projects.remove(&project_id);
550        }
551
552        Ok(LeftProject {
553            connection_ids: project.connection_ids(),
554            host_user_id: project.host_user_id,
555        })
556    }
557
558    pub fn update_worktree(
559        &mut self,
560        connection_id: ConnectionId,
561        project_id: u64,
562        worktree_id: u64,
563        removed_entries: &[u64],
564        updated_entries: &[proto::Entry],
565        scan_id: u64,
566    ) -> Result<Vec<ConnectionId>> {
567        let project = self.write_project(project_id, connection_id)?;
568        let worktree = project
569            .share_mut()?
570            .worktrees
571            .get_mut(&worktree_id)
572            .ok_or_else(|| anyhow!("no such worktree"))?;
573        for entry_id in removed_entries {
574            worktree.entries.remove(&entry_id);
575        }
576        for entry in updated_entries {
577            worktree.entries.insert(entry.id, entry.clone());
578        }
579        worktree.scan_id = scan_id;
580        let connection_ids = project.connection_ids();
581        Ok(connection_ids)
582    }
583
584    pub fn project_connection_ids(
585        &self,
586        project_id: u64,
587        acting_connection_id: ConnectionId,
588    ) -> Result<Vec<ConnectionId>> {
589        Ok(self
590            .read_project(project_id, acting_connection_id)?
591            .connection_ids())
592    }
593
594    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
595        Ok(self
596            .channels
597            .get(&channel_id)
598            .ok_or_else(|| anyhow!("no such channel"))?
599            .connection_ids())
600    }
601
602    pub fn project(&self, project_id: u64) -> Result<&Project> {
603        self.projects
604            .get(&project_id)
605            .ok_or_else(|| anyhow!("no such project"))
606    }
607
608    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
609        let project = self
610            .projects
611            .get(&project_id)
612            .ok_or_else(|| anyhow!("no such project"))?;
613        if project.host_connection_id == connection_id
614            || project
615                .share
616                .as_ref()
617                .ok_or_else(|| anyhow!("project is not shared"))?
618                .guests
619                .contains_key(&connection_id)
620        {
621            Ok(project)
622        } else {
623            Err(anyhow!("no such project"))?
624        }
625    }
626
627    fn write_project(
628        &mut self,
629        project_id: u64,
630        connection_id: ConnectionId,
631    ) -> Result<&mut Project> {
632        let project = self
633            .projects
634            .get_mut(&project_id)
635            .ok_or_else(|| anyhow!("no such project"))?;
636        if project.host_connection_id == connection_id
637            || project
638                .share
639                .as_ref()
640                .ok_or_else(|| anyhow!("project is not shared"))?
641                .guests
642                .contains_key(&connection_id)
643        {
644            Ok(project)
645        } else {
646            Err(anyhow!("no such project"))?
647        }
648    }
649
650    #[cfg(test)]
651    pub fn check_invariants(&self) {
652        for (connection_id, connection) in &self.connections {
653            for project_id in &connection.projects {
654                let project = &self.projects.get(&project_id).unwrap();
655                if project.host_connection_id != *connection_id {
656                    assert!(project
657                        .share
658                        .as_ref()
659                        .unwrap()
660                        .guests
661                        .contains_key(connection_id));
662                }
663
664                if let Some(share) = project.share.as_ref() {
665                    for (worktree_id, worktree) in share.worktrees.iter() {
666                        let mut paths = HashMap::default();
667                        for entry in worktree.entries.values() {
668                            let prev_entry = paths.insert(&entry.path, entry);
669                            assert_eq!(
670                                prev_entry,
671                                None,
672                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
673                                worktree_id,
674                                prev_entry.unwrap(),
675                                entry
676                            );
677                        }
678                    }
679                }
680            }
681            for channel_id in &connection.channels {
682                let channel = self.channels.get(channel_id).unwrap();
683                assert!(channel.connection_ids.contains(connection_id));
684            }
685            assert!(self
686                .connections_by_user_id
687                .get(&connection.user_id)
688                .unwrap()
689                .contains(connection_id));
690        }
691
692        for (user_id, connection_ids) in &self.connections_by_user_id {
693            for connection_id in connection_ids {
694                assert_eq!(
695                    self.connections.get(connection_id).unwrap().user_id,
696                    *user_id
697                );
698            }
699        }
700
701        for (project_id, project) in &self.projects {
702            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
703            assert!(host_connection.projects.contains(project_id));
704
705            if let Some(share) = &project.share {
706                for guest_connection_id in share.guests.keys() {
707                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
708                    assert!(guest_connection.projects.contains(project_id));
709                }
710                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
711                assert_eq!(
712                    share.active_replica_ids,
713                    share
714                        .guests
715                        .values()
716                        .map(|(replica_id, _)| *replica_id)
717                        .collect::<HashSet<_>>(),
718                );
719            }
720        }
721
722        for (channel_id, channel) in &self.channels {
723            for connection_id in &channel.connection_ids {
724                let connection = self.connections.get(connection_id).unwrap();
725                assert!(connection.channels.contains(channel_id));
726            }
727        }
728    }
729}
730
731impl Project {
732    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
733        if let Some(share) = &self.share {
734            share.guests.keys().copied().collect()
735        } else {
736            Vec::new()
737        }
738    }
739
740    pub fn connection_ids(&self) -> Vec<ConnectionId> {
741        if let Some(share) = &self.share {
742            share
743                .guests
744                .keys()
745                .copied()
746                .chain(Some(self.host_connection_id))
747                .collect()
748        } else {
749            vec![self.host_connection_id]
750        }
751    }
752
753    pub fn share(&self) -> Result<&ProjectShare> {
754        Ok(self
755            .share
756            .as_ref()
757            .ok_or_else(|| anyhow!("worktree is not shared"))?)
758    }
759
760    fn share_mut(&mut self) -> Result<&mut ProjectShare> {
761        Ok(self
762            .share
763            .as_mut()
764            .ok_or_else(|| anyhow!("worktree is not shared"))?)
765    }
766}
767
768impl Channel {
769    fn connection_ids(&self) -> Vec<ConnectionId> {
770        self.connection_ids.iter().copied().collect()
771    }
772}