store.rs

  1use crate::db::{self, ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::{collections::hash_map, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28    pub language_servers: Vec<proto::LanguageServer>,
 29}
 30
 31pub struct Worktree {
 32    pub root_name: String,
 33    pub visible: bool,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40    pub worktrees: HashMap<u64, WorktreeShare>,
 41}
 42
 43#[derive(Default)]
 44pub struct WorktreeShare {
 45    pub entries: HashMap<u64, proto::Entry>,
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47    pub scan_id: u64,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub user_id: UserId,
 60    pub hosted_projects: HashMap<u64, Project>,
 61    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 62    pub contact_ids: HashSet<UserId>,
 63}
 64
 65pub struct JoinedProject<'a> {
 66    pub replica_id: ReplicaId,
 67    pub project: &'a Project,
 68}
 69
 70pub struct SharedProject {}
 71
 72pub struct UnsharedProject {
 73    pub connection_ids: Vec<ConnectionId>,
 74    pub host_user_id: UserId,
 75}
 76
 77pub struct LeftProject {
 78    pub connection_ids: Vec<ConnectionId>,
 79    pub host_user_id: UserId,
 80}
 81
 82#[derive(Copy, Clone)]
 83pub struct Metrics {
 84    pub connections: usize,
 85    pub registered_projects: usize,
 86    pub shared_projects: usize,
 87}
 88
 89impl Store {
 90    pub fn metrics(&self) -> Metrics {
 91        let connections = self.connections.len();
 92        let mut registered_projects = 0;
 93        let mut shared_projects = 0;
 94        for project in self.projects.values() {
 95            registered_projects += 1;
 96            if project.share.is_some() {
 97                shared_projects += 1;
 98            }
 99        }
100
101        Metrics {
102            connections,
103            registered_projects,
104            shared_projects,
105        }
106    }
107
108    #[instrument(skip(self))]
109    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
110        self.connections.insert(
111            connection_id,
112            ConnectionState {
113                user_id,
114                projects: Default::default(),
115                channels: Default::default(),
116            },
117        );
118        self.connections_by_user_id
119            .entry(user_id)
120            .or_default()
121            .insert(connection_id);
122    }
123
124    #[instrument(skip(self))]
125    pub fn remove_connection(
126        &mut self,
127        connection_id: ConnectionId,
128    ) -> Result<RemovedConnectionState> {
129        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
130            connection
131        } else {
132            return Err(anyhow!("no such connection"))?;
133        };
134
135        for channel_id in &connection.channels {
136            if let Some(channel) = self.channels.get_mut(&channel_id) {
137                channel.connection_ids.remove(&connection_id);
138            }
139        }
140
141        let user_connections = self
142            .connections_by_user_id
143            .get_mut(&connection.user_id)
144            .unwrap();
145        user_connections.remove(&connection_id);
146        if user_connections.is_empty() {
147            self.connections_by_user_id.remove(&connection.user_id);
148        }
149
150        let mut result = RemovedConnectionState::default();
151        result.user_id = connection.user_id;
152        for project_id in connection.projects.clone() {
153            if let Ok(project) = self.unregister_project(project_id, connection_id) {
154                result.hosted_projects.insert(project_id, project);
155            } else if let Ok(project) = self.leave_project(connection_id, project_id) {
156                result
157                    .guest_project_ids
158                    .insert(project_id, project.connection_ids);
159            }
160        }
161
162        Ok(result)
163    }
164
165    #[cfg(test)]
166    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
167        self.channels.get(&id)
168    }
169
170    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
171        if let Some(connection) = self.connections.get_mut(&connection_id) {
172            connection.channels.insert(channel_id);
173            self.channels
174                .entry(channel_id)
175                .or_default()
176                .connection_ids
177                .insert(connection_id);
178        }
179    }
180
181    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
182        if let Some(connection) = self.connections.get_mut(&connection_id) {
183            connection.channels.remove(&channel_id);
184            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
185                entry.get_mut().connection_ids.remove(&connection_id);
186                if entry.get_mut().connection_ids.is_empty() {
187                    entry.remove();
188                }
189            }
190        }
191    }
192
193    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
194        Ok(self
195            .connections
196            .get(&connection_id)
197            .ok_or_else(|| anyhow!("unknown connection"))?
198            .user_id)
199    }
200
201    pub fn connection_ids_for_user<'a>(
202        &'a self,
203        user_id: UserId,
204    ) -> impl 'a + Iterator<Item = ConnectionId> {
205        self.connections_by_user_id
206            .get(&user_id)
207            .into_iter()
208            .flatten()
209            .copied()
210    }
211
212    pub fn is_user_online(&self, user_id: UserId) -> bool {
213        !self
214            .connections_by_user_id
215            .get(&user_id)
216            .unwrap_or(&Default::default())
217            .is_empty()
218    }
219
220    pub fn build_initial_contacts_update(&self, contacts: db::Contacts) -> proto::UpdateContacts {
221        let mut update = proto::UpdateContacts::default();
222        for user_id in contacts.current {
223            update.contacts.push(self.contact_for_user(user_id));
224        }
225
226        for request in contacts.incoming_requests {
227            update
228                .incoming_requests
229                .push(proto::IncomingContactRequest {
230                    requester_id: request.requester_id.to_proto(),
231                    should_notify: request.should_notify,
232                })
233        }
234
235        for requested_user_id in contacts.outgoing_requests {
236            update.outgoing_requests.push(requested_user_id.to_proto())
237        }
238
239        update
240    }
241
242    pub fn contact_for_user(&self, user_id: UserId) -> proto::Contact {
243        proto::Contact {
244            user_id: user_id.to_proto(),
245            projects: self.project_metadata_for_user(user_id),
246            online: self.is_user_online(user_id),
247        }
248    }
249
250    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
251        let connection_ids = self.connections_by_user_id.get(&user_id);
252        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
253            connection_ids
254                .iter()
255                .filter_map(|connection_id| self.connections.get(connection_id))
256                .flat_map(|connection| connection.projects.iter().copied())
257        });
258
259        let mut metadata = Vec::new();
260        for project_id in project_ids {
261            if let Some(project) = self.projects.get(&project_id) {
262                metadata.push(proto::ProjectMetadata {
263                    id: project_id,
264                    is_shared: project.share.is_some(),
265                    worktree_root_names: project
266                        .worktrees
267                        .values()
268                        .map(|worktree| worktree.root_name.clone())
269                        .collect(),
270                    guests: project
271                        .share
272                        .iter()
273                        .flat_map(|share| {
274                            share.guests.values().map(|(_, user_id)| user_id.to_proto())
275                        })
276                        .collect(),
277                });
278            }
279        }
280
281        metadata
282    }
283
284    // pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
285    //     let mut contacts = HashMap::default();
286    //     for project_id in self
287    //         .visible_projects_by_user_id
288    //         .get(&user_id)
289    //         .unwrap_or(&HashSet::default())
290    //     {
291    //         let project = &self.projects[project_id];
292
293    //         let mut guests = HashSet::default();
294    //         if let Ok(share) = project.share() {
295    //             for guest_connection_id in share.guests.keys() {
296    //                 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
297    //                     guests.insert(user_id.to_proto());
298    //                 }
299    //             }
300    //         }
301
302    //         if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
303    //             let mut worktree_root_names = project
304    //                 .worktrees
305    //                 .values()
306    //                 .filter(|worktree| worktree.visible)
307    //                 .map(|worktree| worktree.root_name.clone())
308    //                 .collect::<Vec<_>>();
309    //             worktree_root_names.sort_unstable();
310    //             contacts
311    //                 .entry(host_user_id)
312    //                 .or_insert_with(|| proto::Contact {
313    //                     user_id: host_user_id.to_proto(),
314    //                     projects: Vec::new(),
315    //                 })
316    //                 .projects
317    //                 .push(proto::ProjectMetadata {
318    //                     id: *project_id,
319    //                     worktree_root_names,
320    //                     is_shared: project.share.is_some(),
321    //                     guests: guests.into_iter().collect(),
322    //                 });
323    //         }
324    //     }
325
326    //     contacts.into_values().collect()
327    // }
328
329    pub fn register_project(
330        &mut self,
331        host_connection_id: ConnectionId,
332        host_user_id: UserId,
333    ) -> u64 {
334        let project_id = self.next_project_id;
335        self.projects.insert(
336            project_id,
337            Project {
338                host_connection_id,
339                host_user_id,
340                share: None,
341                worktrees: Default::default(),
342                language_servers: Default::default(),
343            },
344        );
345        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
346            connection.projects.insert(project_id);
347        }
348        self.next_project_id += 1;
349        project_id
350    }
351
352    pub fn register_worktree(
353        &mut self,
354        project_id: u64,
355        worktree_id: u64,
356        connection_id: ConnectionId,
357        worktree: Worktree,
358    ) -> Result<()> {
359        let project = self
360            .projects
361            .get_mut(&project_id)
362            .ok_or_else(|| anyhow!("no such project"))?;
363        if project.host_connection_id == connection_id {
364            project.worktrees.insert(worktree_id, worktree);
365            if let Ok(share) = project.share_mut() {
366                share.worktrees.insert(worktree_id, Default::default());
367            }
368
369            Ok(())
370        } else {
371            Err(anyhow!("no such project"))?
372        }
373    }
374
375    pub fn unregister_project(
376        &mut self,
377        project_id: u64,
378        connection_id: ConnectionId,
379    ) -> Result<Project> {
380        match self.projects.entry(project_id) {
381            hash_map::Entry::Occupied(e) => {
382                if e.get().host_connection_id == connection_id {
383                    let project = e.remove();
384
385                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
386                        host_connection.projects.remove(&project_id);
387                    }
388
389                    if let Some(share) = &project.share {
390                        for guest_connection in share.guests.keys() {
391                            if let Some(connection) = self.connections.get_mut(&guest_connection) {
392                                connection.projects.remove(&project_id);
393                            }
394                        }
395                    }
396
397                    Ok(project)
398                } else {
399                    Err(anyhow!("no such project"))?
400                }
401            }
402            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
403        }
404    }
405
406    pub fn unregister_worktree(
407        &mut self,
408        project_id: u64,
409        worktree_id: u64,
410        acting_connection_id: ConnectionId,
411    ) -> Result<(Worktree, Vec<ConnectionId>)> {
412        let project = self
413            .projects
414            .get_mut(&project_id)
415            .ok_or_else(|| anyhow!("no such project"))?;
416        if project.host_connection_id != acting_connection_id {
417            Err(anyhow!("not your worktree"))?;
418        }
419
420        let worktree = project
421            .worktrees
422            .remove(&worktree_id)
423            .ok_or_else(|| anyhow!("no such worktree"))?;
424
425        let mut guest_connection_ids = Vec::new();
426        if let Ok(share) = project.share_mut() {
427            guest_connection_ids.extend(share.guests.keys());
428            share.worktrees.remove(&worktree_id);
429        }
430
431        Ok((worktree, guest_connection_ids))
432    }
433
434    pub fn share_project(
435        &mut self,
436        project_id: u64,
437        connection_id: ConnectionId,
438    ) -> Result<SharedProject> {
439        if let Some(project) = self.projects.get_mut(&project_id) {
440            if project.host_connection_id == connection_id {
441                let mut share = ProjectShare::default();
442                for worktree_id in project.worktrees.keys() {
443                    share.worktrees.insert(*worktree_id, Default::default());
444                }
445                project.share = Some(share);
446                return Ok(SharedProject {});
447            }
448        }
449        Err(anyhow!("no such project"))?
450    }
451
452    pub fn unshare_project(
453        &mut self,
454        project_id: u64,
455        acting_connection_id: ConnectionId,
456    ) -> Result<UnsharedProject> {
457        let project = if let Some(project) = self.projects.get_mut(&project_id) {
458            project
459        } else {
460            return Err(anyhow!("no such project"))?;
461        };
462
463        if project.host_connection_id != acting_connection_id {
464            return Err(anyhow!("not your project"))?;
465        }
466
467        let connection_ids = project.connection_ids();
468        if let Some(share) = project.share.take() {
469            for connection_id in share.guests.into_keys() {
470                if let Some(connection) = self.connections.get_mut(&connection_id) {
471                    connection.projects.remove(&project_id);
472                }
473            }
474
475            Ok(UnsharedProject {
476                connection_ids,
477                host_user_id: project.host_user_id,
478            })
479        } else {
480            Err(anyhow!("project is not shared"))?
481        }
482    }
483
484    pub fn update_diagnostic_summary(
485        &mut self,
486        project_id: u64,
487        worktree_id: u64,
488        connection_id: ConnectionId,
489        summary: proto::DiagnosticSummary,
490    ) -> Result<Vec<ConnectionId>> {
491        let project = self
492            .projects
493            .get_mut(&project_id)
494            .ok_or_else(|| anyhow!("no such project"))?;
495        if project.host_connection_id == connection_id {
496            let worktree = project
497                .share_mut()?
498                .worktrees
499                .get_mut(&worktree_id)
500                .ok_or_else(|| anyhow!("no such worktree"))?;
501            worktree
502                .diagnostic_summaries
503                .insert(summary.path.clone().into(), summary);
504            return Ok(project.connection_ids());
505        }
506
507        Err(anyhow!("no such worktree"))?
508    }
509
510    pub fn start_language_server(
511        &mut self,
512        project_id: u64,
513        connection_id: ConnectionId,
514        language_server: proto::LanguageServer,
515    ) -> Result<Vec<ConnectionId>> {
516        let project = self
517            .projects
518            .get_mut(&project_id)
519            .ok_or_else(|| anyhow!("no such project"))?;
520        if project.host_connection_id == connection_id {
521            project.language_servers.push(language_server);
522            return Ok(project.connection_ids());
523        }
524
525        Err(anyhow!("no such project"))?
526    }
527
528    pub fn join_project(
529        &mut self,
530        connection_id: ConnectionId,
531        user_id: UserId,
532        project_id: u64,
533    ) -> Result<JoinedProject> {
534        let connection = self
535            .connections
536            .get_mut(&connection_id)
537            .ok_or_else(|| anyhow!("no such connection"))?;
538        let project = self
539            .projects
540            .get_mut(&project_id)
541            .ok_or_else(|| anyhow!("no such project"))?;
542
543        let share = project.share_mut()?;
544        connection.projects.insert(project_id);
545
546        let mut replica_id = 1;
547        while share.active_replica_ids.contains(&replica_id) {
548            replica_id += 1;
549        }
550        share.active_replica_ids.insert(replica_id);
551        share.guests.insert(connection_id, (replica_id, user_id));
552
553        Ok(JoinedProject {
554            replica_id,
555            project: &self.projects[&project_id],
556        })
557    }
558
559    pub fn leave_project(
560        &mut self,
561        connection_id: ConnectionId,
562        project_id: u64,
563    ) -> Result<LeftProject> {
564        let project = self
565            .projects
566            .get_mut(&project_id)
567            .ok_or_else(|| anyhow!("no such project"))?;
568        let share = project
569            .share
570            .as_mut()
571            .ok_or_else(|| anyhow!("project is not shared"))?;
572        let (replica_id, _) = share
573            .guests
574            .remove(&connection_id)
575            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
576        share.active_replica_ids.remove(&replica_id);
577
578        if let Some(connection) = self.connections.get_mut(&connection_id) {
579            connection.projects.remove(&project_id);
580        }
581
582        Ok(LeftProject {
583            connection_ids: project.connection_ids(),
584            host_user_id: project.host_user_id,
585        })
586    }
587
588    pub fn update_worktree(
589        &mut self,
590        connection_id: ConnectionId,
591        project_id: u64,
592        worktree_id: u64,
593        removed_entries: &[u64],
594        updated_entries: &[proto::Entry],
595        scan_id: u64,
596    ) -> Result<Vec<ConnectionId>> {
597        let project = self.write_project(project_id, connection_id)?;
598        let worktree = project
599            .share_mut()?
600            .worktrees
601            .get_mut(&worktree_id)
602            .ok_or_else(|| anyhow!("no such worktree"))?;
603        for entry_id in removed_entries {
604            worktree.entries.remove(&entry_id);
605        }
606        for entry in updated_entries {
607            worktree.entries.insert(entry.id, entry.clone());
608        }
609        worktree.scan_id = scan_id;
610        let connection_ids = project.connection_ids();
611        Ok(connection_ids)
612    }
613
614    pub fn project_connection_ids(
615        &self,
616        project_id: u64,
617        acting_connection_id: ConnectionId,
618    ) -> Result<Vec<ConnectionId>> {
619        Ok(self
620            .read_project(project_id, acting_connection_id)?
621            .connection_ids())
622    }
623
624    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
625        Ok(self
626            .channels
627            .get(&channel_id)
628            .ok_or_else(|| anyhow!("no such channel"))?
629            .connection_ids())
630    }
631
632    pub fn project(&self, project_id: u64) -> Result<&Project> {
633        self.projects
634            .get(&project_id)
635            .ok_or_else(|| anyhow!("no such project"))
636    }
637
638    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
639        let project = self
640            .projects
641            .get(&project_id)
642            .ok_or_else(|| anyhow!("no such project"))?;
643        if project.host_connection_id == connection_id
644            || project
645                .share
646                .as_ref()
647                .ok_or_else(|| anyhow!("project is not shared"))?
648                .guests
649                .contains_key(&connection_id)
650        {
651            Ok(project)
652        } else {
653            Err(anyhow!("no such project"))?
654        }
655    }
656
657    fn write_project(
658        &mut self,
659        project_id: u64,
660        connection_id: ConnectionId,
661    ) -> Result<&mut Project> {
662        let project = self
663            .projects
664            .get_mut(&project_id)
665            .ok_or_else(|| anyhow!("no such project"))?;
666        if project.host_connection_id == connection_id
667            || project
668                .share
669                .as_ref()
670                .ok_or_else(|| anyhow!("project is not shared"))?
671                .guests
672                .contains_key(&connection_id)
673        {
674            Ok(project)
675        } else {
676            Err(anyhow!("no such project"))?
677        }
678    }
679
680    #[cfg(test)]
681    pub fn check_invariants(&self) {
682        for (connection_id, connection) in &self.connections {
683            for project_id in &connection.projects {
684                let project = &self.projects.get(&project_id).unwrap();
685                if project.host_connection_id != *connection_id {
686                    assert!(project
687                        .share
688                        .as_ref()
689                        .unwrap()
690                        .guests
691                        .contains_key(connection_id));
692                }
693
694                if let Some(share) = project.share.as_ref() {
695                    for (worktree_id, worktree) in share.worktrees.iter() {
696                        let mut paths = HashMap::default();
697                        for entry in worktree.entries.values() {
698                            let prev_entry = paths.insert(&entry.path, entry);
699                            assert_eq!(
700                                prev_entry,
701                                None,
702                                "worktree {:?}, duplicate path for entries {:?} and {:?}",
703                                worktree_id,
704                                prev_entry.unwrap(),
705                                entry
706                            );
707                        }
708                    }
709                }
710            }
711            for channel_id in &connection.channels {
712                let channel = self.channels.get(channel_id).unwrap();
713                assert!(channel.connection_ids.contains(connection_id));
714            }
715            assert!(self
716                .connections_by_user_id
717                .get(&connection.user_id)
718                .unwrap()
719                .contains(connection_id));
720        }
721
722        for (user_id, connection_ids) in &self.connections_by_user_id {
723            for connection_id in connection_ids {
724                assert_eq!(
725                    self.connections.get(connection_id).unwrap().user_id,
726                    *user_id
727                );
728            }
729        }
730
731        for (project_id, project) in &self.projects {
732            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
733            assert!(host_connection.projects.contains(project_id));
734
735            if let Some(share) = &project.share {
736                for guest_connection_id in share.guests.keys() {
737                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
738                    assert!(guest_connection.projects.contains(project_id));
739                }
740                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
741                assert_eq!(
742                    share.active_replica_ids,
743                    share
744                        .guests
745                        .values()
746                        .map(|(replica_id, _)| *replica_id)
747                        .collect::<HashSet<_>>(),
748                );
749            }
750        }
751
752        for (channel_id, channel) in &self.channels {
753            for connection_id in &channel.connection_ids {
754                let connection = self.connections.get(connection_id).unwrap();
755                assert!(connection.channels.contains(channel_id));
756            }
757        }
758    }
759}
760
761impl Project {
762    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
763        if let Some(share) = &self.share {
764            share.guests.keys().copied().collect()
765        } else {
766            Vec::new()
767        }
768    }
769
770    pub fn connection_ids(&self) -> Vec<ConnectionId> {
771        if let Some(share) = &self.share {
772            share
773                .guests
774                .keys()
775                .copied()
776                .chain(Some(self.host_connection_id))
777                .collect()
778        } else {
779            vec![self.host_connection_id]
780        }
781    }
782
783    pub fn share(&self) -> Result<&ProjectShare> {
784        Ok(self
785            .share
786            .as_ref()
787            .ok_or_else(|| anyhow!("worktree is not shared"))?)
788    }
789
790    fn share_mut(&mut self) -> Result<&mut ProjectShare> {
791        Ok(self
792            .share
793            .as_mut()
794            .ok_or_else(|| anyhow!("worktree is not shared"))?)
795    }
796}
797
798impl Channel {
799    fn connection_ids(&self) -> Vec<ConnectionId> {
800        self.connection_ids.iter().copied().collect()
801    }
802}