store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::collections::hash_map;
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    pub worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33    pub share: Option<WorktreeShare>,
 34}
 35
 36#[derive(Default)]
 37pub struct ProjectShare {
 38    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 39    pub active_replica_ids: HashSet<ReplicaId>,
 40}
 41
 42pub struct WorktreeShare {
 43    pub entries: HashMap<u64, proto::Entry>,
 44}
 45
 46#[derive(Default)]
 47pub struct Channel {
 48    pub connection_ids: HashSet<ConnectionId>,
 49}
 50
 51pub type ReplicaId = u16;
 52
 53#[derive(Default)]
 54pub struct RemovedConnectionState {
 55    pub hosted_projects: HashMap<u64, Project>,
 56    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 57    pub contact_ids: HashSet<UserId>,
 58}
 59
 60pub struct JoinedProject<'a> {
 61    pub replica_id: ReplicaId,
 62    pub project: &'a Project,
 63}
 64
 65pub struct UnsharedWorktree {
 66    pub connection_ids: Vec<ConnectionId>,
 67    pub authorized_user_ids: Vec<UserId>,
 68}
 69
 70pub struct LeftProject {
 71    pub connection_ids: Vec<ConnectionId>,
 72    pub authorized_user_ids: Vec<UserId>,
 73}
 74
 75impl Store {
 76    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 77        self.connections.insert(
 78            connection_id,
 79            ConnectionState {
 80                user_id,
 81                projects: Default::default(),
 82                channels: Default::default(),
 83            },
 84        );
 85        self.connections_by_user_id
 86            .entry(user_id)
 87            .or_default()
 88            .insert(connection_id);
 89    }
 90
 91    pub fn remove_connection(
 92        &mut self,
 93        connection_id: ConnectionId,
 94    ) -> tide::Result<RemovedConnectionState> {
 95        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 96            connection
 97        } else {
 98            return Err(anyhow!("no such connection"))?;
 99        };
100
101        for channel_id in &connection.channels {
102            if let Some(channel) = self.channels.get_mut(&channel_id) {
103                channel.connection_ids.remove(&connection_id);
104            }
105        }
106
107        let user_connections = self
108            .connections_by_user_id
109            .get_mut(&connection.user_id)
110            .unwrap();
111        user_connections.remove(&connection_id);
112        if user_connections.is_empty() {
113            self.connections_by_user_id.remove(&connection.user_id);
114        }
115
116        let mut result = RemovedConnectionState::default();
117        for project_id in connection.projects.clone() {
118            if let Some(project) = self.unregister_project(project_id, connection_id) {
119                result.contact_ids.extend(project.authorized_user_ids());
120                result.hosted_projects.insert(project_id, project);
121            } else if let Some(project) = self.leave_project(connection_id, project_id) {
122                result
123                    .guest_project_ids
124                    .insert(project_id, project.connection_ids);
125                result.contact_ids.extend(project.authorized_user_ids);
126            }
127        }
128
129        #[cfg(test)]
130        self.check_invariants();
131
132        Ok(result)
133    }
134
135    #[cfg(test)]
136    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
137        self.channels.get(&id)
138    }
139
140    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
141        if let Some(connection) = self.connections.get_mut(&connection_id) {
142            connection.channels.insert(channel_id);
143            self.channels
144                .entry(channel_id)
145                .or_default()
146                .connection_ids
147                .insert(connection_id);
148        }
149    }
150
151    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
152        if let Some(connection) = self.connections.get_mut(&connection_id) {
153            connection.channels.remove(&channel_id);
154            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
155                entry.get_mut().connection_ids.remove(&connection_id);
156                if entry.get_mut().connection_ids.is_empty() {
157                    entry.remove();
158                }
159            }
160        }
161    }
162
163    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
164        Ok(self
165            .connections
166            .get(&connection_id)
167            .ok_or_else(|| anyhow!("unknown connection"))?
168            .user_id)
169    }
170
171    pub fn connection_ids_for_user<'a>(
172        &'a self,
173        user_id: UserId,
174    ) -> impl 'a + Iterator<Item = ConnectionId> {
175        self.connections_by_user_id
176            .get(&user_id)
177            .into_iter()
178            .flatten()
179            .copied()
180    }
181
182    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
183        let mut contacts = HashMap::default();
184        for project_id in self
185            .visible_projects_by_user_id
186            .get(&user_id)
187            .unwrap_or(&HashSet::default())
188        {
189            let project = &self.projects[project_id];
190
191            let mut guests = HashSet::default();
192            if let Ok(share) = project.share() {
193                for guest_connection_id in share.guests.keys() {
194                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
195                        guests.insert(user_id.to_proto());
196                    }
197                }
198            }
199
200            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
201                let mut worktree_root_names = project
202                    .worktrees
203                    .values()
204                    .map(|worktree| worktree.root_name.clone())
205                    .collect::<Vec<_>>();
206                worktree_root_names.sort_unstable();
207                contacts
208                    .entry(host_user_id)
209                    .or_insert_with(|| proto::Contact {
210                        user_id: host_user_id.to_proto(),
211                        projects: Vec::new(),
212                    })
213                    .projects
214                    .push(proto::ProjectMetadata {
215                        id: *project_id,
216                        worktree_root_names,
217                        is_shared: project.share.is_some(),
218                        guests: guests.into_iter().collect(),
219                    });
220            }
221        }
222
223        contacts.into_values().collect()
224    }
225
226    pub fn register_project(
227        &mut self,
228        host_connection_id: ConnectionId,
229        host_user_id: UserId,
230    ) -> u64 {
231        let project_id = self.next_project_id;
232        self.projects.insert(
233            project_id,
234            Project {
235                host_connection_id,
236                host_user_id,
237                share: None,
238                worktrees: Default::default(),
239            },
240        );
241        self.next_project_id += 1;
242        project_id
243    }
244
245    pub fn register_worktree(
246        &mut self,
247        project_id: u64,
248        worktree_id: u64,
249        worktree: Worktree,
250    ) -> bool {
251        if let Some(project) = self.projects.get_mut(&project_id) {
252            for authorized_user_id in &worktree.authorized_user_ids {
253                self.visible_projects_by_user_id
254                    .entry(*authorized_user_id)
255                    .or_default()
256                    .insert(project_id);
257            }
258            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
259                connection.projects.insert(project_id);
260            }
261            project.worktrees.insert(worktree_id, worktree);
262
263            #[cfg(test)]
264            self.check_invariants();
265            true
266        } else {
267            false
268        }
269    }
270
271    pub fn unregister_project(
272        &mut self,
273        project_id: u64,
274        connection_id: ConnectionId,
275    ) -> Option<Project> {
276        match self.projects.entry(project_id) {
277            hash_map::Entry::Occupied(e) => {
278                if e.get().host_connection_id == connection_id {
279                    for user_id in e.get().authorized_user_ids() {
280                        if let hash_map::Entry::Occupied(mut projects) =
281                            self.visible_projects_by_user_id.entry(user_id)
282                        {
283                            projects.get_mut().remove(&project_id);
284                        }
285                    }
286
287                    Some(e.remove())
288                } else {
289                    None
290                }
291            }
292            hash_map::Entry::Vacant(_) => None,
293        }
294    }
295
296    pub fn unregister_worktree(
297        &mut self,
298        project_id: u64,
299        worktree_id: u64,
300        acting_connection_id: ConnectionId,
301    ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
302        let project = self
303            .projects
304            .get_mut(&project_id)
305            .ok_or_else(|| anyhow!("no such project"))?;
306        if project.host_connection_id != acting_connection_id {
307            Err(anyhow!("not your worktree"))?;
308        }
309
310        let worktree = project
311            .worktrees
312            .remove(&worktree_id)
313            .ok_or_else(|| anyhow!("no such worktree"))?;
314
315        let mut guest_connection_ids = Vec::new();
316        if let Some(share) = &project.share {
317            guest_connection_ids.extend(share.guests.keys());
318        }
319
320        for authorized_user_id in &worktree.authorized_user_ids {
321            if let Some(visible_projects) =
322                self.visible_projects_by_user_id.get_mut(authorized_user_id)
323            {
324                if !project.has_authorized_user_id(*authorized_user_id) {
325                    visible_projects.remove(&project_id);
326                }
327            }
328        }
329
330        #[cfg(test)]
331        self.check_invariants();
332
333        Ok((worktree, guest_connection_ids))
334    }
335
336    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
337        if let Some(project) = self.projects.get_mut(&project_id) {
338            if project.host_connection_id == connection_id {
339                project.share = Some(ProjectShare::default());
340                return true;
341            }
342        }
343        false
344    }
345
346    pub fn unshare_project(
347        &mut self,
348        project_id: u64,
349        acting_connection_id: ConnectionId,
350    ) -> tide::Result<UnsharedWorktree> {
351        let project = if let Some(project) = self.projects.get_mut(&project_id) {
352            project
353        } else {
354            return Err(anyhow!("no such project"))?;
355        };
356
357        if project.host_connection_id != acting_connection_id {
358            return Err(anyhow!("not your project"))?;
359        }
360
361        let connection_ids = project.connection_ids();
362        let authorized_user_ids = project.authorized_user_ids();
363        if let Some(share) = project.share.take() {
364            for connection_id in share.guests.into_keys() {
365                if let Some(connection) = self.connections.get_mut(&connection_id) {
366                    connection.projects.remove(&project_id);
367                }
368            }
369
370            #[cfg(test)]
371            self.check_invariants();
372
373            Ok(UnsharedWorktree {
374                connection_ids,
375                authorized_user_ids,
376            })
377        } else {
378            Err(anyhow!("project is not shared"))?
379        }
380    }
381
382    pub fn share_worktree(
383        &mut self,
384        project_id: u64,
385        worktree_id: u64,
386        connection_id: ConnectionId,
387        entries: HashMap<u64, proto::Entry>,
388    ) -> Option<Vec<UserId>> {
389        let project = self.projects.get_mut(&project_id)?;
390        let worktree = project.worktrees.get_mut(&worktree_id)?;
391        if project.host_connection_id == connection_id && project.share.is_some() {
392            worktree.share = Some(WorktreeShare { entries });
393            Some(project.authorized_user_ids())
394        } else {
395            None
396        }
397    }
398
399    pub fn join_project(
400        &mut self,
401        connection_id: ConnectionId,
402        user_id: UserId,
403        project_id: u64,
404    ) -> tide::Result<JoinedProject> {
405        let connection = self
406            .connections
407            .get_mut(&connection_id)
408            .ok_or_else(|| anyhow!("no such connection"))?;
409        let project = self
410            .projects
411            .get_mut(&project_id)
412            .and_then(|project| {
413                if project.has_authorized_user_id(user_id) {
414                    Some(project)
415                } else {
416                    None
417                }
418            })
419            .ok_or_else(|| anyhow!("no such project"))?;
420
421        let share = project.share_mut()?;
422        connection.projects.insert(project_id);
423
424        let mut replica_id = 1;
425        while share.active_replica_ids.contains(&replica_id) {
426            replica_id += 1;
427        }
428        share.active_replica_ids.insert(replica_id);
429        share.guests.insert(connection_id, (replica_id, user_id));
430
431        #[cfg(test)]
432        self.check_invariants();
433
434        Ok(JoinedProject {
435            replica_id,
436            project: &self.projects[&project_id],
437        })
438    }
439
440    pub fn leave_project(
441        &mut self,
442        connection_id: ConnectionId,
443        project_id: u64,
444    ) -> Option<LeftProject> {
445        let project = self.projects.get_mut(&project_id)?;
446        let share = project.share.as_mut()?;
447        let (replica_id, _) = share.guests.remove(&connection_id)?;
448        share.active_replica_ids.remove(&replica_id);
449
450        if let Some(connection) = self.connections.get_mut(&connection_id) {
451            connection.projects.remove(&project_id);
452        }
453
454        let connection_ids = project.connection_ids();
455        let authorized_user_ids = project.authorized_user_ids();
456
457        #[cfg(test)]
458        self.check_invariants();
459
460        Some(LeftProject {
461            connection_ids,
462            authorized_user_ids,
463        })
464    }
465
466    pub fn update_worktree(
467        &mut self,
468        connection_id: ConnectionId,
469        project_id: u64,
470        worktree_id: u64,
471        removed_entries: &[u64],
472        updated_entries: &[proto::Entry],
473    ) -> Option<Vec<ConnectionId>> {
474        let project = self.write_project(project_id, connection_id)?;
475        let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
476        for entry_id in removed_entries {
477            share.entries.remove(&entry_id);
478        }
479        for entry in updated_entries {
480            share.entries.insert(entry.id, entry.clone());
481        }
482        Some(project.connection_ids())
483    }
484
485    pub fn project_connection_ids(
486        &self,
487        project_id: u64,
488        acting_connection_id: ConnectionId,
489    ) -> Option<Vec<ConnectionId>> {
490        Some(
491            self.read_project(project_id, acting_connection_id)?
492                .connection_ids(),
493        )
494    }
495
496    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
497        Some(self.channels.get(&channel_id)?.connection_ids())
498    }
499
500    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
501        let project = self.projects.get(&project_id)?;
502        if project.host_connection_id == connection_id
503            || project.share.as_ref()?.guests.contains_key(&connection_id)
504        {
505            Some(project)
506        } else {
507            None
508        }
509    }
510
511    fn write_project(
512        &mut self,
513        project_id: u64,
514        connection_id: ConnectionId,
515    ) -> Option<&mut Project> {
516        let project = self.projects.get_mut(&project_id)?;
517        if project.host_connection_id == connection_id
518            || project.share.as_ref()?.guests.contains_key(&connection_id)
519        {
520            Some(project)
521        } else {
522            None
523        }
524    }
525
526    #[cfg(test)]
527    fn check_invariants(&self) {
528        for (connection_id, connection) in &self.connections {
529            for project_id in &connection.projects {
530                let project = &self.projects.get(&project_id).unwrap();
531                if project.host_connection_id != *connection_id {
532                    assert!(project
533                        .share
534                        .as_ref()
535                        .unwrap()
536                        .guests
537                        .contains_key(connection_id));
538                }
539            }
540            for channel_id in &connection.channels {
541                let channel = self.channels.get(channel_id).unwrap();
542                assert!(channel.connection_ids.contains(connection_id));
543            }
544            assert!(self
545                .connections_by_user_id
546                .get(&connection.user_id)
547                .unwrap()
548                .contains(connection_id));
549        }
550
551        for (user_id, connection_ids) in &self.connections_by_user_id {
552            for connection_id in connection_ids {
553                assert_eq!(
554                    self.connections.get(connection_id).unwrap().user_id,
555                    *user_id
556                );
557            }
558        }
559
560        for (project_id, project) in &self.projects {
561            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
562            assert!(host_connection.projects.contains(project_id));
563
564            for authorized_user_ids in project.authorized_user_ids() {
565                let visible_project_ids = self
566                    .visible_projects_by_user_id
567                    .get(&authorized_user_ids)
568                    .unwrap();
569                assert!(visible_project_ids.contains(project_id));
570            }
571
572            if let Some(share) = &project.share {
573                for guest_connection_id in share.guests.keys() {
574                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
575                    assert!(guest_connection.projects.contains(project_id));
576                }
577                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
578                assert_eq!(
579                    share.active_replica_ids,
580                    share
581                        .guests
582                        .values()
583                        .map(|(replica_id, _)| *replica_id)
584                        .collect::<HashSet<_>>(),
585                );
586            }
587        }
588
589        for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
590            for project_id in visible_project_ids {
591                let project = self.projects.get(project_id).unwrap();
592                assert!(project.authorized_user_ids().contains(user_id));
593            }
594        }
595
596        for (channel_id, channel) in &self.channels {
597            for connection_id in &channel.connection_ids {
598                let connection = self.connections.get(connection_id).unwrap();
599                assert!(connection.channels.contains(channel_id));
600            }
601        }
602    }
603}
604
605impl Project {
606    pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
607        self.worktrees
608            .values()
609            .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
610    }
611
612    pub fn authorized_user_ids(&self) -> Vec<UserId> {
613        let mut ids = self
614            .worktrees
615            .values()
616            .flat_map(|worktree| worktree.authorized_user_ids.iter())
617            .copied()
618            .collect::<Vec<_>>();
619        ids.sort_unstable();
620        ids.dedup();
621        ids
622    }
623
624    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
625        if let Some(share) = &self.share {
626            share.guests.keys().copied().collect()
627        } else {
628            Vec::new()
629        }
630    }
631
632    pub fn connection_ids(&self) -> Vec<ConnectionId> {
633        if let Some(share) = &self.share {
634            share
635                .guests
636                .keys()
637                .copied()
638                .chain(Some(self.host_connection_id))
639                .collect()
640        } else {
641            vec![self.host_connection_id]
642        }
643    }
644
645    pub fn share(&self) -> tide::Result<&ProjectShare> {
646        Ok(self
647            .share
648            .as_ref()
649            .ok_or_else(|| anyhow!("worktree is not shared"))?)
650    }
651
652    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
653        Ok(self
654            .share
655            .as_mut()
656            .ok_or_else(|| anyhow!("worktree is not shared"))?)
657    }
658}
659
660impl Channel {
661    fn connection_ids(&self) -> Vec<ConnectionId> {
662        self.connection_ids.iter().copied().collect()
663    }
664}