// store.rs

use crate::db::{ChannelId, UserId};
use anyhow::anyhow;
use std::collections::{hash_map, HashMap, HashSet};
use zrpc::{proto, ConnectionId};

  6#[derive(Default)]
  7pub struct Store {
  8    connections: HashMap<ConnectionId, ConnectionState>,
  9    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 10    worktrees: HashMap<u64, Worktree>,
 11    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 12    channels: HashMap<ChannelId, Channel>,
 13    next_worktree_id: u64,
 14}
 15
 16struct ConnectionState {
 17    user_id: UserId,
 18    worktrees: HashSet<u64>,
 19    channels: HashSet<ChannelId>,
 20}
 21
 22pub struct Worktree {
 23    pub host_connection_id: ConnectionId,
 24    pub collaborator_user_ids: Vec<UserId>,
 25    pub root_name: String,
 26    pub share: Option<WorktreeShare>,
 27}
 28
 29pub struct WorktreeShare {
 30    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 31    pub active_replica_ids: HashSet<ReplicaId>,
 32    pub entries: HashMap<u64, proto::Entry>,
 33}
 34
 35#[derive(Default)]
 36pub struct Channel {
 37    pub connection_ids: HashSet<ConnectionId>,
 38}
 39
 40pub type ReplicaId = u16;
 41
 42#[derive(Default)]
 43pub struct RemovedConnectionState {
 44    pub hosted_worktrees: HashMap<u64, Worktree>,
 45    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 46    pub collaborator_ids: HashSet<UserId>,
 47}
 48
 49pub struct JoinedWorktree<'a> {
 50    pub replica_id: ReplicaId,
 51    pub worktree: &'a Worktree,
 52}
 53
 54pub struct UnsharedWorktree {
 55    pub connection_ids: Vec<ConnectionId>,
 56    pub collaborator_ids: Vec<UserId>,
 57}
 58
 59pub struct LeftWorktree {
 60    pub connection_ids: Vec<ConnectionId>,
 61    pub collaborator_ids: Vec<UserId>,
 62}
 63
 64impl Store {
 65    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 66        self.connections.insert(
 67            connection_id,
 68            ConnectionState {
 69                user_id,
 70                worktrees: Default::default(),
 71                channels: Default::default(),
 72            },
 73        );
 74        self.connections_by_user_id
 75            .entry(user_id)
 76            .or_default()
 77            .insert(connection_id);
 78    }
 79
 80    pub fn remove_connection(
 81        &mut self,
 82        connection_id: ConnectionId,
 83    ) -> tide::Result<RemovedConnectionState> {
 84        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 85            connection
 86        } else {
 87            return Err(anyhow!("no such connection"))?;
 88        };
 89
 90        for channel_id in &connection.channels {
 91            if let Some(channel) = self.channels.get_mut(&channel_id) {
 92                channel.connection_ids.remove(&connection_id);
 93            }
 94        }
 95
 96        let user_connections = self
 97            .connections_by_user_id
 98            .get_mut(&connection.user_id)
 99            .unwrap();
100        user_connections.remove(&connection_id);
101        if user_connections.is_empty() {
102            self.connections_by_user_id.remove(&connection.user_id);
103        }
104
105        let mut result = RemovedConnectionState::default();
106        for worktree_id in connection.worktrees.clone() {
107            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108                result
109                    .collaborator_ids
110                    .extend(worktree.collaborator_user_ids.iter().copied());
111                result.hosted_worktrees.insert(worktree_id, worktree);
112            } else {
113                if let Some(worktree) = self.worktrees.get(&worktree_id) {
114                    result
115                        .guest_worktree_ids
116                        .insert(worktree_id, worktree.connection_ids());
117                    result
118                        .collaborator_ids
119                        .extend(worktree.collaborator_user_ids.iter().copied());
120                }
121            }
122        }
123
124        Ok(result)
125    }
126
127    #[cfg(test)]
128    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129        self.channels.get(&id)
130    }
131
132    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133        if let Some(connection) = self.connections.get_mut(&connection_id) {
134            connection.channels.insert(channel_id);
135            self.channels
136                .entry(channel_id)
137                .or_default()
138                .connection_ids
139                .insert(connection_id);
140        }
141    }
142
143    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.remove(&channel_id);
146            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147                entry.get_mut().connection_ids.remove(&connection_id);
148                if entry.get_mut().connection_ids.is_empty() {
149                    entry.remove();
150                }
151            }
152        }
153    }
154
155    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156        Ok(self
157            .connections
158            .get(&connection_id)
159            .ok_or_else(|| anyhow!("unknown connection"))?
160            .user_id)
161    }
162
163    pub fn connection_ids_for_user<'a>(
164        &'a self,
165        user_id: UserId,
166    ) -> impl 'a + Iterator<Item = ConnectionId> {
167        self.connections_by_user_id
168            .get(&user_id)
169            .into_iter()
170            .flatten()
171            .copied()
172    }
173
174    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
175        let mut collaborators = HashMap::new();
176        for worktree_id in self
177            .visible_worktrees_by_user_id
178            .get(&user_id)
179            .unwrap_or(&HashSet::new())
180        {
181            let worktree = &self.worktrees[worktree_id];
182
183            let mut guests = HashSet::new();
184            if let Ok(share) = worktree.share() {
185                for guest_connection_id in share.guest_connection_ids.keys() {
186                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187                        guests.insert(user_id.to_proto());
188                    }
189                }
190            }
191
192            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193                collaborators
194                    .entry(host_user_id)
195                    .or_insert_with(|| proto::Collaborator {
196                        user_id: host_user_id.to_proto(),
197                        worktrees: Vec::new(),
198                    })
199                    .worktrees
200                    .push(proto::WorktreeMetadata {
201                        id: *worktree_id,
202                        root_name: worktree.root_name.clone(),
203                        is_shared: worktree.share.is_some(),
204                        guests: guests.into_iter().collect(),
205                    });
206            }
207        }
208
209        collaborators.into_values().collect()
210    }
211
212    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213        let worktree_id = self.next_worktree_id;
214        for collaborator_user_id in &worktree.collaborator_user_ids {
215            self.visible_worktrees_by_user_id
216                .entry(*collaborator_user_id)
217                .or_default()
218                .insert(worktree_id);
219        }
220        self.next_worktree_id += 1;
221        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222            connection.worktrees.insert(worktree_id);
223        }
224        self.worktrees.insert(worktree_id, worktree);
225
226        #[cfg(test)]
227        self.check_invariants();
228
229        worktree_id
230    }
231
232    pub fn remove_worktree(
233        &mut self,
234        worktree_id: u64,
235        acting_connection_id: ConnectionId,
236    ) -> tide::Result<Worktree> {
237        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238            if e.get().host_connection_id != acting_connection_id {
239                Err(anyhow!("not your worktree"))?;
240            }
241            e.remove()
242        } else {
243            return Err(anyhow!("no such worktree"))?;
244        };
245
246        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247            connection.worktrees.remove(&worktree_id);
248        }
249
250        if let Some(share) = &worktree.share {
251            for connection_id in share.guest_connection_ids.keys() {
252                if let Some(connection) = self.connections.get_mut(connection_id) {
253                    connection.worktrees.remove(&worktree_id);
254                }
255            }
256        }
257
258        for collaborator_user_id in &worktree.collaborator_user_ids {
259            if let Some(visible_worktrees) = self
260                .visible_worktrees_by_user_id
261                .get_mut(&collaborator_user_id)
262            {
263                visible_worktrees.remove(&worktree_id);
264            }
265        }
266
267        #[cfg(test)]
268        self.check_invariants();
269
270        Ok(worktree)
271    }
272
273    pub fn share_worktree(
274        &mut self,
275        worktree_id: u64,
276        connection_id: ConnectionId,
277        entries: HashMap<u64, proto::Entry>,
278    ) -> Option<Vec<UserId>> {
279        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280            if worktree.host_connection_id == connection_id {
281                worktree.share = Some(WorktreeShare {
282                    guest_connection_ids: Default::default(),
283                    active_replica_ids: Default::default(),
284                    entries,
285                });
286                return Some(worktree.collaborator_user_ids.clone());
287            }
288        }
289        None
290    }
291
292    pub fn unshare_worktree(
293        &mut self,
294        worktree_id: u64,
295        acting_connection_id: ConnectionId,
296    ) -> tide::Result<UnsharedWorktree> {
297        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298            worktree
299        } else {
300            return Err(anyhow!("no such worktree"))?;
301        };
302
303        if worktree.host_connection_id != acting_connection_id {
304            return Err(anyhow!("not your worktree"))?;
305        }
306
307        let connection_ids = worktree.connection_ids();
308        let collaborator_ids = worktree.collaborator_user_ids.clone();
309        if let Some(share) = worktree.share.take() {
310            for connection_id in share.guest_connection_ids.into_keys() {
311                if let Some(connection) = self.connections.get_mut(&connection_id) {
312                    connection.worktrees.remove(&worktree_id);
313                }
314            }
315
316            #[cfg(test)]
317            self.check_invariants();
318
319            Ok(UnsharedWorktree {
320                connection_ids,
321                collaborator_ids,
322            })
323        } else {
324            Err(anyhow!("worktree is not shared"))?
325        }
326    }
327
328    pub fn join_worktree(
329        &mut self,
330        connection_id: ConnectionId,
331        user_id: UserId,
332        worktree_id: u64,
333    ) -> tide::Result<JoinedWorktree> {
334        let connection = self
335            .connections
336            .get_mut(&connection_id)
337            .ok_or_else(|| anyhow!("no such connection"))?;
338        let worktree = self
339            .worktrees
340            .get_mut(&worktree_id)
341            .and_then(|worktree| {
342                if worktree.collaborator_user_ids.contains(&user_id) {
343                    Some(worktree)
344                } else {
345                    None
346                }
347            })
348            .ok_or_else(|| anyhow!("no such worktree"))?;
349
350        let share = worktree.share_mut()?;
351        connection.worktrees.insert(worktree_id);
352
353        let mut replica_id = 1;
354        while share.active_replica_ids.contains(&replica_id) {
355            replica_id += 1;
356        }
357        share.active_replica_ids.insert(replica_id);
358        share.guest_connection_ids.insert(connection_id, replica_id);
359
360        #[cfg(test)]
361        self.check_invariants();
362
363        Ok(JoinedWorktree {
364            replica_id,
365            worktree: &self.worktrees[&worktree_id],
366        })
367    }
368
369    pub fn leave_worktree(
370        &mut self,
371        connection_id: ConnectionId,
372        worktree_id: u64,
373    ) -> Option<LeftWorktree> {
374        let worktree = self.worktrees.get_mut(&worktree_id)?;
375        let share = worktree.share.as_mut()?;
376        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
377        share.active_replica_ids.remove(&replica_id);
378
379        let connection = self.connections.get_mut(&connection_id)?;
380        connection.worktrees.remove(&worktree_id);
381
382        let connection_ids = worktree.connection_ids();
383        let collaborator_ids = worktree.collaborator_user_ids.clone();
384
385        #[cfg(test)]
386        self.check_invariants();
387
388        Some(LeftWorktree {
389            connection_ids,
390            collaborator_ids,
391        })
392    }
393
394    pub fn update_worktree(
395        &mut self,
396        connection_id: ConnectionId,
397        worktree_id: u64,
398        removed_entries: &[u64],
399        updated_entries: &[proto::Entry],
400    ) -> tide::Result<Vec<ConnectionId>> {
401        let worktree = self.write_worktree(worktree_id, connection_id)?;
402        let share = worktree.share_mut()?;
403        for entry_id in removed_entries {
404            share.entries.remove(&entry_id);
405        }
406        for entry in updated_entries {
407            share.entries.insert(entry.id, entry.clone());
408        }
409        Ok(worktree.connection_ids())
410    }
411
412    pub fn worktree_host_connection_id(
413        &self,
414        connection_id: ConnectionId,
415        worktree_id: u64,
416    ) -> tide::Result<ConnectionId> {
417        Ok(self
418            .read_worktree(worktree_id, connection_id)?
419            .host_connection_id)
420    }
421
422    pub fn worktree_guest_connection_ids(
423        &self,
424        connection_id: ConnectionId,
425        worktree_id: u64,
426    ) -> tide::Result<Vec<ConnectionId>> {
427        Ok(self
428            .read_worktree(worktree_id, connection_id)?
429            .share()?
430            .guest_connection_ids
431            .keys()
432            .copied()
433            .collect())
434    }
435
436    pub fn worktree_connection_ids(
437        &self,
438        connection_id: ConnectionId,
439        worktree_id: u64,
440    ) -> tide::Result<Vec<ConnectionId>> {
441        Ok(self
442            .read_worktree(worktree_id, connection_id)?
443            .connection_ids())
444    }
445
446    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447        Some(self.channels.get(&channel_id)?.connection_ids())
448    }
449
450    fn read_worktree(
451        &self,
452        worktree_id: u64,
453        connection_id: ConnectionId,
454    ) -> tide::Result<&Worktree> {
455        let worktree = self
456            .worktrees
457            .get(&worktree_id)
458            .ok_or_else(|| anyhow!("worktree not found"))?;
459
460        if worktree.host_connection_id == connection_id
461            || worktree
462                .share()?
463                .guest_connection_ids
464                .contains_key(&connection_id)
465        {
466            Ok(worktree)
467        } else {
468            Err(anyhow!(
469                "{} is not a member of worktree {}",
470                connection_id,
471                worktree_id
472            ))?
473        }
474    }
475
476    fn write_worktree(
477        &mut self,
478        worktree_id: u64,
479        connection_id: ConnectionId,
480    ) -> tide::Result<&mut Worktree> {
481        let worktree = self
482            .worktrees
483            .get_mut(&worktree_id)
484            .ok_or_else(|| anyhow!("worktree not found"))?;
485
486        if worktree.host_connection_id == connection_id
487            || worktree.share.as_ref().map_or(false, |share| {
488                share.guest_connection_ids.contains_key(&connection_id)
489            })
490        {
491            Ok(worktree)
492        } else {
493            Err(anyhow!(
494                "{} is not a member of worktree {}",
495                connection_id,
496                worktree_id
497            ))?
498        }
499    }
500
501    #[cfg(test)]
502    fn check_invariants(&self) {
503        for (connection_id, connection) in &self.connections {
504            for worktree_id in &connection.worktrees {
505                let worktree = &self.worktrees.get(&worktree_id).unwrap();
506                if worktree.host_connection_id != *connection_id {
507                    assert!(worktree
508                        .share()
509                        .unwrap()
510                        .guest_connection_ids
511                        .contains_key(connection_id));
512                }
513            }
514            for channel_id in &connection.channels {
515                let channel = self.channels.get(channel_id).unwrap();
516                assert!(channel.connection_ids.contains(connection_id));
517            }
518            assert!(self
519                .connections_by_user_id
520                .get(&connection.user_id)
521                .unwrap()
522                .contains(connection_id));
523        }
524
525        for (user_id, connection_ids) in &self.connections_by_user_id {
526            for connection_id in connection_ids {
527                assert_eq!(
528                    self.connections.get(connection_id).unwrap().user_id,
529                    *user_id
530                );
531            }
532        }
533
534        for (worktree_id, worktree) in &self.worktrees {
535            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
536            assert!(host_connection.worktrees.contains(worktree_id));
537
538            for collaborator_id in &worktree.collaborator_user_ids {
539                let visible_worktree_ids = self
540                    .visible_worktrees_by_user_id
541                    .get(collaborator_id)
542                    .unwrap();
543                assert!(visible_worktree_ids.contains(worktree_id));
544            }
545
546            if let Some(share) = &worktree.share {
547                for guest_connection_id in share.guest_connection_ids.keys() {
548                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
549                    assert!(guest_connection.worktrees.contains(worktree_id));
550                }
551                assert_eq!(
552                    share.active_replica_ids.len(),
553                    share.guest_connection_ids.len(),
554                );
555                assert_eq!(
556                    share.active_replica_ids,
557                    share
558                        .guest_connection_ids
559                        .values()
560                        .copied()
561                        .collect::<HashSet<_>>(),
562                );
563            }
564        }
565
566        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
567            for worktree_id in visible_worktree_ids {
568                let worktree = self.worktrees.get(worktree_id).unwrap();
569                assert!(worktree.collaborator_user_ids.contains(user_id));
570            }
571        }
572
573        for (channel_id, channel) in &self.channels {
574            for connection_id in &channel.connection_ids {
575                let connection = self.connections.get(connection_id).unwrap();
576                assert!(connection.channels.contains(channel_id));
577            }
578        }
579    }
580}
581
582impl Worktree {
583    pub fn connection_ids(&self) -> Vec<ConnectionId> {
584        if let Some(share) = &self.share {
585            share
586                .guest_connection_ids
587                .keys()
588                .copied()
589                .chain(Some(self.host_connection_id))
590                .collect()
591        } else {
592            vec![self.host_connection_id]
593        }
594    }
595
596    pub fn share(&self) -> tide::Result<&WorktreeShare> {
597        Ok(self
598            .share
599            .as_ref()
600            .ok_or_else(|| anyhow!("worktree is not shared"))?)
601    }
602
603    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
604        Ok(self
605            .share
606            .as_mut()
607            .ok_or_else(|| anyhow!("worktree is not shared"))?)
608    }
609}
610
611impl Channel {
612    fn connection_ids(&self) -> Vec<ConnectionId> {
613        self.connection_ids.iter().copied().collect()
614    }
615}