store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::collections::hash_map;
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    projects: HashMap<u64, Project>,
 12    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_worktree_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Project {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub share: Option<ProjectShare>,
 27    worktrees: HashMap<u64, Worktree>,
 28}
 29
 30pub struct Worktree {
 31    pub authorized_user_ids: Vec<UserId>,
 32    pub root_name: String,
 33}
 34
 35#[derive(Default)]
 36pub struct ProjectShare {
 37    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 38    pub active_replica_ids: HashSet<ReplicaId>,
 39    pub worktrees: HashMap<u64, WorktreeShare>,
 40}
 41
 42pub struct WorktreeShare {
 43    pub entries: HashMap<u64, proto::Entry>,
 44}
 45
 46#[derive(Default)]
 47pub struct Channel {
 48    pub connection_ids: HashSet<ConnectionId>,
 49}
 50
 51pub type ReplicaId = u16;
 52
 53#[derive(Default)]
 54pub struct RemovedConnectionState {
 55    pub hosted_projects: HashMap<u64, Project>,
 56    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
 57    pub contact_ids: HashSet<UserId>,
 58}
 59
 60pub struct JoinedWorktree<'a> {
 61    pub replica_id: ReplicaId,
 62    pub worktree: &'a Worktree,
 63}
 64
 65pub struct UnsharedWorktree {
 66    pub connection_ids: Vec<ConnectionId>,
 67    pub authorized_user_ids: Vec<UserId>,
 68}
 69
 70pub struct LeftWorktree {
 71    pub connection_ids: Vec<ConnectionId>,
 72    pub authorized_user_ids: Vec<UserId>,
 73}
 74
 75impl Store {
 76    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 77        self.connections.insert(
 78            connection_id,
 79            ConnectionState {
 80                user_id,
 81                projects: Default::default(),
 82                channels: Default::default(),
 83            },
 84        );
 85        self.connections_by_user_id
 86            .entry(user_id)
 87            .or_default()
 88            .insert(connection_id);
 89    }
 90
 91    pub fn remove_connection(
 92        &mut self,
 93        connection_id: ConnectionId,
 94    ) -> tide::Result<RemovedConnectionState> {
 95        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 96            connection
 97        } else {
 98            return Err(anyhow!("no such connection"))?;
 99        };
100
101        for channel_id in &connection.channels {
102            if let Some(channel) = self.channels.get_mut(&channel_id) {
103                channel.connection_ids.remove(&connection_id);
104            }
105        }
106
107        let user_connections = self
108            .connections_by_user_id
109            .get_mut(&connection.user_id)
110            .unwrap();
111        user_connections.remove(&connection_id);
112        if user_connections.is_empty() {
113            self.connections_by_user_id.remove(&connection.user_id);
114        }
115
116        let mut result = RemovedConnectionState::default();
117        for worktree_id in connection.worktrees.clone() {
118            if let Ok(worktree) = self.unregister_worktree(worktree_id, connection_id) {
119                result
120                    .contact_ids
121                    .extend(worktree.authorized_user_ids.iter().copied());
122                result.hosted_worktrees.insert(worktree_id, worktree);
123            } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
124                result
125                    .guest_worktree_ids
126                    .insert(worktree_id, worktree.connection_ids);
127                result.contact_ids.extend(worktree.authorized_user_ids);
128            }
129        }
130
131        #[cfg(test)]
132        self.check_invariants();
133
134        Ok(result)
135    }
136
137    #[cfg(test)]
138    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
139        self.channels.get(&id)
140    }
141
142    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143        if let Some(connection) = self.connections.get_mut(&connection_id) {
144            connection.channels.insert(channel_id);
145            self.channels
146                .entry(channel_id)
147                .or_default()
148                .connection_ids
149                .insert(connection_id);
150        }
151    }
152
153    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
154        if let Some(connection) = self.connections.get_mut(&connection_id) {
155            connection.channels.remove(&channel_id);
156            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
157                entry.get_mut().connection_ids.remove(&connection_id);
158                if entry.get_mut().connection_ids.is_empty() {
159                    entry.remove();
160                }
161            }
162        }
163    }
164
165    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
166        Ok(self
167            .connections
168            .get(&connection_id)
169            .ok_or_else(|| anyhow!("unknown connection"))?
170            .user_id)
171    }
172
173    pub fn connection_ids_for_user<'a>(
174        &'a self,
175        user_id: UserId,
176    ) -> impl 'a + Iterator<Item = ConnectionId> {
177        self.connections_by_user_id
178            .get(&user_id)
179            .into_iter()
180            .flatten()
181            .copied()
182    }
183
184    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
185        let mut contacts = HashMap::default();
186        for project_id in self
187            .visible_projects_by_user_id
188            .get(&user_id)
189            .unwrap_or(&HashSet::default())
190        {
191            let project = &self.projects[project_id];
192
193            let mut guests = HashSet::default();
194            if let Ok(share) = worktree.share() {
195                for guest_connection_id in share.guests.keys() {
196                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
197                        guests.insert(user_id.to_proto());
198                    }
199                }
200            }
201
202            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
203                contacts
204                    .entry(host_user_id)
205                    .or_insert_with(|| proto::Contact {
206                        user_id: host_user_id.to_proto(),
207                        projects: Vec::new(),
208                    })
209                    .projects
210                    .push(proto::ProjectMetadata {
211                        id: *project_id,
212                        worktree_root_names: project
213                            .worktrees
214                            .iter()
215                            .map(|worktree| worktree.root_name.clone())
216                            .collect(),
217                        is_shared: project.share.is_some(),
218                        guests: guests.into_iter().collect(),
219                    });
220            }
221        }
222
223        contacts.into_values().collect()
224    }
225
226    pub fn register_project(
227        &mut self,
228        host_connection_id: ConnectionId,
229        host_user_id: UserId,
230    ) -> u64 {
231        let project_id = self.next_project_id;
232        self.projects.insert(
233            project_id,
234            Project {
235                host_connection_id,
236                host_user_id,
237                share: None,
238                worktrees: Default::default(),
239            },
240        );
241        self.next_project_id += 1;
242        project_id
243    }
244
245    pub fn register_worktree(
246        &mut self,
247        project_id: u64,
248        worktree_id: u64,
249        worktree: Worktree,
250    ) -> bool {
251        if let Some(project) = self.projects.get_mut(&project_id) {
252            for authorized_user_id in &worktree.authorized_user_ids {
253                self.visible_projects_by_user_id
254                    .entry(*authorized_user_id)
255                    .or_default()
256                    .insert(project_id);
257            }
258            if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
259                connection.projects.insert(project_id);
260            }
261            project.worktrees.insert(worktree_id, worktree);
262
263            #[cfg(test)]
264            self.check_invariants();
265            true
266        } else {
267            false
268        }
269    }
270
271    pub fn unregister_project(&mut self, project_id: u64) {
272        todo!()
273    }
274
275    pub fn unregister_worktree(
276        &mut self,
277        project_id: u64,
278        worktree_id: u64,
279        acting_connection_id: ConnectionId,
280    ) -> tide::Result<Worktree> {
281        let project = self
282            .projects
283            .get_mut(&project_id)
284            .ok_or_else(|| anyhow!("no such project"))?;
285        if project.host_connection_id != acting_connection_id {
286            Err(anyhow!("not your worktree"))?;
287        }
288
289        let worktree = project
290            .worktrees
291            .remove(&worktree_id)
292            .ok_or_else(|| anyhow!("no such worktree"))?;
293
294        if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
295            connection.worktrees.remove(&worktree_id);
296        }
297
298        if let Some(share) = &worktree.share {
299            for connection_id in share.guests.keys() {
300                if let Some(connection) = self.connections.get_mut(connection_id) {
301                    connection.worktrees.remove(&worktree_id);
302                }
303            }
304        }
305
306        for authorized_user_id in &worktree.authorized_user_ids {
307            if let Some(visible_worktrees) = self
308                .visible_worktrees_by_user_id
309                .get_mut(&authorized_user_id)
310            {
311                visible_worktrees.remove(&worktree_id);
312            }
313        }
314
315        #[cfg(test)]
316        self.check_invariants();
317
318        Ok(worktree)
319    }
320
321    pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
322        if let Some(project) = self.projects.get_mut(&project_id) {
323            if project.host_connection_id == connection_id {
324                project.share = Some(ProjectShare::default());
325                return true;
326            }
327        }
328        false
329    }
330
331    pub fn share_worktree(
332        &mut self,
333        project_id: u64,
334        worktree_id: u64,
335        connection_id: ConnectionId,
336        entries: HashMap<u64, proto::Entry>,
337    ) -> Option<Vec<UserId>> {
338        if let Some(project) = self.projects.get_mut(&project_id) {
339            if project.host_connection_id == connection_id {
340                if let Some(share) = project.share.as_mut() {
341                    share
342                        .worktrees
343                        .insert(worktree_id, WorktreeShare { entries });
344                    return Some(project.authorized_user_ids());
345                }
346            }
347        }
348        None
349    }
350
    /// Stops sharing `worktree_id`, removing every guest from it.
    ///
    /// Returns the connections that were participating (captured before the share is
    /// dropped) together with the worktree's authorized users.
    ///
    /// # Errors
    ///
    /// Fails if the worktree does not exist, if `acting_connection_id` is not its
    /// host, or if the worktree is not currently shared.
    ///
    /// NOTE(review): this body still targets the pre-project data model.
    /// `self.worktrees`, `Worktree::host_connection_id`, `Worktree::share` and
    /// `ConnectionState::worktrees` are not declared by the structs in this file
    /// (sharing now lives in `Project::share`), so it does not compile as written.
    /// Porting it likely needs a `project_id` parameter — confirm against callers.
    pub fn unshare_worktree(
        &mut self,
        worktree_id: u64,
        acting_connection_id: ConnectionId,
    ) -> tide::Result<UnsharedWorktree> {
        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
            worktree
        } else {
            return Err(anyhow!("no such worktree"))?;
        };

        if worktree.host_connection_id != acting_connection_id {
            return Err(anyhow!("not your worktree"))?;
        }

        // Snapshot the participants before `take()` below drops the guest list.
        let connection_ids = worktree.connection_ids();
        let authorized_user_ids = worktree.authorized_user_ids.clone();
        if let Some(share) = worktree.share.take() {
            // Former guests stop tracking the worktree on their connections.
            for connection_id in share.guests.into_keys() {
                if let Some(connection) = self.connections.get_mut(&connection_id) {
                    connection.worktrees.remove(&worktree_id);
                }
            }

            #[cfg(test)]
            self.check_invariants();

            Ok(UnsharedWorktree {
                connection_ids,
                authorized_user_ids,
            })
        } else {
            Err(anyhow!("worktree is not shared"))?
        }
    }
386
    /// Adds `connection_id` (belonging to `user_id`) as a guest of a shared worktree.
    ///
    /// The guest receives the smallest replica id, starting from 1, that is not
    /// already active. Returns that replica id together with a reference to the
    /// worktree.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// * the connection is unknown;
    /// * the worktree does not exist, or `user_id` is not among its authorized users
    ///   (both report "no such worktree");
    /// * the worktree is not shared.
    ///
    /// NOTE(review): this body still targets the pre-project data model.
    /// `self.worktrees`, `Worktree::share` and `ConnectionState::worktrees` are not
    /// declared by the structs in this file, so it does not compile as written.
    /// Porting it likely needs a `project_id` parameter — confirm against callers.
    pub fn join_worktree(
        &mut self,
        connection_id: ConnectionId,
        user_id: UserId,
        worktree_id: u64,
    ) -> tide::Result<JoinedWorktree> {
        let connection = self
            .connections
            .get_mut(&connection_id)
            .ok_or_else(|| anyhow!("no such connection"))?;
        // An unauthorized user gets the same error as a missing worktree.
        let worktree = self
            .worktrees
            .get_mut(&worktree_id)
            .and_then(|worktree| {
                if worktree.authorized_user_ids.contains(&user_id) {
                    Some(worktree)
                } else {
                    None
                }
            })
            .ok_or_else(|| anyhow!("no such worktree"))?;

        // Resolve the share before touching `connection`, so a "not shared" error
        // leaves no partial state behind.
        let share = worktree.share_mut()?;
        connection.worktrees.insert(worktree_id);

        // Guests are numbered from 1 (presumably 0 is the host's — TODO confirm).
        let mut replica_id = 1;
        while share.active_replica_ids.contains(&replica_id) {
            replica_id += 1;
        }
        share.active_replica_ids.insert(replica_id);
        share.guests.insert(connection_id, (replica_id, user_id));

        #[cfg(test)]
        self.check_invariants();

        Ok(JoinedWorktree {
            replica_id,
            worktree: &self.worktrees[&worktree_id],
        })
    }
427
    /// Removes guest `connection_id` from a shared worktree and frees its replica id.
    ///
    /// Returns `None` if the worktree does not exist, is not shared, or the connection
    /// is not one of its guests. Otherwise returns the connections still participating
    /// (computed after the removal) and the worktree's authorized users.
    ///
    /// NOTE(review): this body still targets the pre-project data model.
    /// `self.worktrees`, `Worktree::share` and `ConnectionState::worktrees` are not
    /// declared by the structs in this file, so it does not compile as written.
    /// Porting it likely needs a `project_id` parameter — confirm against callers.
    pub fn leave_worktree(
        &mut self,
        connection_id: ConnectionId,
        worktree_id: u64,
    ) -> Option<LeftWorktree> {
        let worktree = self.worktrees.get_mut(&worktree_id)?;
        let share = worktree.share.as_mut()?;
        let (replica_id, _) = share.guests.remove(&connection_id)?;
        // Release the replica id so a later guest can reuse it.
        share.active_replica_ids.remove(&replica_id);

        if let Some(connection) = self.connections.get_mut(&connection_id) {
            connection.worktrees.remove(&worktree_id);
        }

        // Computed after the guest was removed, so it lists only who remains.
        let connection_ids = worktree.connection_ids();
        let authorized_user_ids = worktree.authorized_user_ids.clone();

        #[cfg(test)]
        self.check_invariants();

        Some(LeftWorktree {
            connection_ids,
            authorized_user_ids,
        })
    }
453
    /// Applies an entry diff to a shared worktree and returns the worktree's
    /// connection ids.
    ///
    /// The ids in `removed_entries` are deleted first. Then each entry in
    /// `updated_entries` is inserted under its `id`, replacing any existing entry
    /// with that id.
    ///
    /// # Errors
    ///
    /// Fails if `write_worktree` rejects the connection (unknown worktree, or a
    /// connection that is neither host nor guest), or if the worktree is not shared.
    ///
    /// NOTE(review): `write_worktree` admits guests as well as the host, so guests can
    /// apply updates here — confirm that is intended. This body also reads
    /// `ProjectShare::entries`, which the `ProjectShare` struct in this file does not
    /// declare (entries now live in `WorktreeShare`), so it does not compile as
    /// written.
    pub fn update_worktree(
        &mut self,
        connection_id: ConnectionId,
        worktree_id: u64,
        removed_entries: &[u64],
        updated_entries: &[proto::Entry],
    ) -> tide::Result<Vec<ConnectionId>> {
        let worktree = self.write_worktree(worktree_id, connection_id)?;
        let share = worktree.share_mut()?;
        for entry_id in removed_entries {
            share.entries.remove(&entry_id);
        }
        for entry in updated_entries {
            share.entries.insert(entry.id, entry.clone());
        }
        Ok(worktree.connection_ids())
    }
471
472    pub fn worktree_host_connection_id(
473        &self,
474        connection_id: ConnectionId,
475        worktree_id: u64,
476    ) -> tide::Result<ConnectionId> {
477        Ok(self
478            .read_worktree(worktree_id, connection_id)?
479            .host_connection_id)
480    }
481
482    pub fn worktree_guest_connection_ids(
483        &self,
484        connection_id: ConnectionId,
485        worktree_id: u64,
486    ) -> tide::Result<Vec<ConnectionId>> {
487        Ok(self
488            .read_worktree(worktree_id, connection_id)?
489            .share()?
490            .guests
491            .keys()
492            .copied()
493            .collect())
494    }
495
496    pub fn worktree_connection_ids(
497        &self,
498        connection_id: ConnectionId,
499        worktree_id: u64,
500    ) -> tide::Result<Vec<ConnectionId>> {
501        Ok(self
502            .read_worktree(worktree_id, connection_id)?
503            .connection_ids())
504    }
505
506    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
507        Some(self.channels.get(&channel_id)?.connection_ids())
508    }
509
    /// Looks up a worktree that `connection_id` is allowed to read.
    ///
    /// The host always passes. Any other connection must be a guest of the current
    /// share.
    ///
    /// # Errors
    ///
    /// Returns "worktree not found" if the id is unknown. For a non-host on an
    /// unshared worktree, the `share()?` call fails first with "worktree is not
    /// shared". In that same situation `write_worktree` instead reports
    /// "... is not a member of worktree ...".
    ///
    /// NOTE(review): this still targets the pre-project model — `self.worktrees` and
    /// `Worktree::host_connection_id` are not declared by the structs in this file.
    fn read_worktree(
        &self,
        worktree_id: u64,
        connection_id: ConnectionId,
    ) -> tide::Result<&Worktree> {
        let worktree = self
            .worktrees
            .get(&worktree_id)
            .ok_or_else(|| anyhow!("worktree not found"))?;

        // `||` short-circuits, so the host never needs the worktree to be shared.
        if worktree.host_connection_id == connection_id
            || worktree.share()?.guests.contains_key(&connection_id)
        {
            Ok(worktree)
        } else {
            Err(anyhow!(
                "{} is not a member of worktree {}",
                connection_id,
                worktree_id
            ))?
        }
    }
532
    /// Mutable lookup of a worktree for `connection_id`, which must be its host or a
    /// guest of its current share.
    ///
    /// # Errors
    ///
    /// Returns "worktree not found" if the id is unknown. Otherwise returns
    /// "... is not a member of worktree ..." when the connection is neither host nor
    /// guest, including when the worktree is not shared at all.
    ///
    /// NOTE(review): this still targets the pre-project model — `self.worktrees`,
    /// `Worktree::host_connection_id` and `Worktree::share` are not declared by the
    /// structs in this file.
    fn write_worktree(
        &mut self,
        worktree_id: u64,
        connection_id: ConnectionId,
    ) -> tide::Result<&mut Worktree> {
        let worktree = self
            .worktrees
            .get_mut(&worktree_id)
            .ok_or_else(|| anyhow!("worktree not found"))?;

        // An unshared worktree counts as "no guests" rather than an error here.
        if worktree.host_connection_id == connection_id
            || worktree
                .share
                .as_ref()
                .map_or(false, |share| share.guests.contains_key(&connection_id))
        {
            Ok(worktree)
        } else {
            Err(anyhow!(
                "{} is not a member of worktree {}",
                connection_id,
                worktree_id
            ))?
        }
    }
558
559    #[cfg(test)]
560    fn check_invariants(&self) {
561        for (connection_id, connection) in &self.connections {
562            for worktree_id in &connection.worktrees {
563                let worktree = &self.worktrees.get(&worktree_id).unwrap();
564                if worktree.host_connection_id != *connection_id {
565                    assert!(worktree.share().unwrap().guests.contains_key(connection_id));
566                }
567            }
568            for channel_id in &connection.channels {
569                let channel = self.channels.get(channel_id).unwrap();
570                assert!(channel.connection_ids.contains(connection_id));
571            }
572            assert!(self
573                .connections_by_user_id
574                .get(&connection.user_id)
575                .unwrap()
576                .contains(connection_id));
577        }
578
579        for (user_id, connection_ids) in &self.connections_by_user_id {
580            for connection_id in connection_ids {
581                assert_eq!(
582                    self.connections.get(connection_id).unwrap().user_id,
583                    *user_id
584                );
585            }
586        }
587
588        for (worktree_id, worktree) in &self.worktrees {
589            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
590            assert!(host_connection.worktrees.contains(worktree_id));
591
592            for authorized_user_ids in &worktree.authorized_user_ids {
593                let visible_worktree_ids = self
594                    .visible_worktrees_by_user_id
595                    .get(authorized_user_ids)
596                    .unwrap();
597                assert!(visible_worktree_ids.contains(worktree_id));
598            }
599
600            if let Some(share) = &worktree.share {
601                for guest_connection_id in share.guests.keys() {
602                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
603                    assert!(guest_connection.worktrees.contains(worktree_id));
604                }
605                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
606                assert_eq!(
607                    share.active_replica_ids,
608                    share
609                        .guests
610                        .values()
611                        .map(|(replica_id, _)| *replica_id)
612                        .collect::<HashSet<_>>(),
613                );
614            }
615        }
616
617        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
618            for worktree_id in visible_worktree_ids {
619                let worktree = self.worktrees.get(worktree_id).unwrap();
620                assert!(worktree.authorized_user_ids.contains(user_id));
621            }
622        }
623
624        for (channel_id, channel) in &self.channels {
625            for connection_id in &channel.connection_ids {
626                let connection = self.connections.get(connection_id).unwrap();
627                assert!(connection.channels.contains(channel_id));
628            }
629        }
630    }
631}
632
633impl Worktree {
634    pub fn connection_ids(&self) -> Vec<ConnectionId> {
635        if let Some(share) = &self.share {
636            share
637                .guests
638                .keys()
639                .copied()
640                .chain(Some(self.host_connection_id))
641                .collect()
642        } else {
643            vec![self.host_connection_id]
644        }
645    }
646
647    pub fn share(&self) -> tide::Result<&ProjectShare> {
648        Ok(self
649            .share
650            .as_ref()
651            .ok_or_else(|| anyhow!("worktree is not shared"))?)
652    }
653
654    fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
655        Ok(self
656            .share
657            .as_mut()
658            .ok_or_else(|| anyhow!("worktree is not shared"))?)
659    }
660}
661
662impl Channel {
663    fn connection_ids(&self) -> Vec<ConnectionId> {
664        self.connection_ids.iter().copied().collect()
665    }
666}