store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use std::collections::{hash_map, HashMap, HashSet};
  4use zrpc::{proto, ConnectionId};
  5
  6#[derive(Default)]
  7pub struct Store {
  8    connections: HashMap<ConnectionId, ConnectionState>,
  9    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 10    worktrees: HashMap<u64, Worktree>,
 11    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 12    channels: HashMap<ChannelId, Channel>,
 13    next_worktree_id: u64,
 14}
 15
 16struct ConnectionState {
 17    user_id: UserId,
 18    worktrees: HashSet<u64>,
 19    channels: HashSet<ChannelId>,
 20}
 21
 22pub struct Worktree {
 23    pub host_connection_id: ConnectionId,
 24    pub collaborator_user_ids: Vec<UserId>,
 25    pub root_name: String,
 26    pub share: Option<WorktreeShare>,
 27}
 28
 29pub struct WorktreeShare {
 30    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 31    pub active_replica_ids: HashSet<ReplicaId>,
 32    pub entries: HashMap<u64, proto::Entry>,
 33}
 34
 35#[derive(Default)]
 36pub struct Channel {
 37    pub connection_ids: HashSet<ConnectionId>,
 38}
 39
 40pub type ReplicaId = u16;
 41
 42#[derive(Default)]
 43pub struct RemovedConnectionState {
 44    pub hosted_worktrees: HashMap<u64, Worktree>,
 45    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 46    pub collaborator_ids: HashSet<UserId>,
 47}
 48
 49pub struct JoinedWorktree<'a> {
 50    pub replica_id: ReplicaId,
 51    pub worktree: &'a Worktree,
 52}
 53
 54pub struct UnsharedWorktree {
 55    pub connection_ids: Vec<ConnectionId>,
 56    pub collaborator_ids: Vec<UserId>,
 57}
 58
 59pub struct LeftWorktree {
 60    pub connection_ids: Vec<ConnectionId>,
 61    pub collaborator_ids: Vec<UserId>,
 62}
 63
 64impl Store {
 65    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 66        self.connections.insert(
 67            connection_id,
 68            ConnectionState {
 69                user_id,
 70                worktrees: Default::default(),
 71                channels: Default::default(),
 72            },
 73        );
 74        self.connections_by_user_id
 75            .entry(user_id)
 76            .or_default()
 77            .insert(connection_id);
 78    }
 79
 80    pub fn remove_connection(
 81        &mut self,
 82        connection_id: ConnectionId,
 83    ) -> tide::Result<RemovedConnectionState> {
 84        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 85            connection
 86        } else {
 87            return Err(anyhow!("no such connection"))?;
 88        };
 89
 90        for channel_id in &connection.channels {
 91            if let Some(channel) = self.channels.get_mut(&channel_id) {
 92                channel.connection_ids.remove(&connection_id);
 93            }
 94        }
 95
 96        let user_connections = self
 97            .connections_by_user_id
 98            .get_mut(&connection.user_id)
 99            .unwrap();
100        user_connections.remove(&connection_id);
101        if user_connections.is_empty() {
102            self.connections_by_user_id.remove(&connection.user_id);
103        }
104
105        let mut result = RemovedConnectionState::default();
106        for worktree_id in connection.worktrees.clone() {
107            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108                result
109                    .collaborator_ids
110                    .extend(worktree.collaborator_user_ids.iter().copied());
111                result.hosted_worktrees.insert(worktree_id, worktree);
112            } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
113                result
114                    .guest_worktree_ids
115                    .insert(worktree_id, worktree.connection_ids);
116                result.collaborator_ids.extend(worktree.collaborator_ids);
117            }
118        }
119
120        #[cfg(test)]
121        self.check_invariants();
122
123        Ok(result)
124    }
125
126    #[cfg(test)]
127    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
128        self.channels.get(&id)
129    }
130
131    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
132        if let Some(connection) = self.connections.get_mut(&connection_id) {
133            connection.channels.insert(channel_id);
134            self.channels
135                .entry(channel_id)
136                .or_default()
137                .connection_ids
138                .insert(connection_id);
139        }
140    }
141
142    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143        if let Some(connection) = self.connections.get_mut(&connection_id) {
144            connection.channels.remove(&channel_id);
145            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
146                entry.get_mut().connection_ids.remove(&connection_id);
147                if entry.get_mut().connection_ids.is_empty() {
148                    entry.remove();
149                }
150            }
151        }
152    }
153
154    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
155        Ok(self
156            .connections
157            .get(&connection_id)
158            .ok_or_else(|| anyhow!("unknown connection"))?
159            .user_id)
160    }
161
162    pub fn connection_ids_for_user<'a>(
163        &'a self,
164        user_id: UserId,
165    ) -> impl 'a + Iterator<Item = ConnectionId> {
166        self.connections_by_user_id
167            .get(&user_id)
168            .into_iter()
169            .flatten()
170            .copied()
171    }
172
173    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
174        let mut collaborators = HashMap::new();
175        for worktree_id in self
176            .visible_worktrees_by_user_id
177            .get(&user_id)
178            .unwrap_or(&HashSet::new())
179        {
180            let worktree = &self.worktrees[worktree_id];
181
182            let mut guests = HashSet::new();
183            if let Ok(share) = worktree.share() {
184                for guest_connection_id in share.guest_connection_ids.keys() {
185                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
186                        guests.insert(user_id.to_proto());
187                    }
188                }
189            }
190
191            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
192                collaborators
193                    .entry(host_user_id)
194                    .or_insert_with(|| proto::Collaborator {
195                        user_id: host_user_id.to_proto(),
196                        worktrees: Vec::new(),
197                    })
198                    .worktrees
199                    .push(proto::WorktreeMetadata {
200                        id: *worktree_id,
201                        root_name: worktree.root_name.clone(),
202                        is_shared: worktree.share.is_some(),
203                        guests: guests.into_iter().collect(),
204                    });
205            }
206        }
207
208        collaborators.into_values().collect()
209    }
210
211    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
212        let worktree_id = self.next_worktree_id;
213        for collaborator_user_id in &worktree.collaborator_user_ids {
214            self.visible_worktrees_by_user_id
215                .entry(*collaborator_user_id)
216                .or_default()
217                .insert(worktree_id);
218        }
219        self.next_worktree_id += 1;
220        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
221            connection.worktrees.insert(worktree_id);
222        }
223        self.worktrees.insert(worktree_id, worktree);
224
225        #[cfg(test)]
226        self.check_invariants();
227
228        worktree_id
229    }
230
231    pub fn remove_worktree(
232        &mut self,
233        worktree_id: u64,
234        acting_connection_id: ConnectionId,
235    ) -> tide::Result<Worktree> {
236        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
237            if e.get().host_connection_id != acting_connection_id {
238                Err(anyhow!("not your worktree"))?;
239            }
240            e.remove()
241        } else {
242            return Err(anyhow!("no such worktree"))?;
243        };
244
245        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
246            connection.worktrees.remove(&worktree_id);
247        }
248
249        if let Some(share) = &worktree.share {
250            for connection_id in share.guest_connection_ids.keys() {
251                if let Some(connection) = self.connections.get_mut(connection_id) {
252                    connection.worktrees.remove(&worktree_id);
253                }
254            }
255        }
256
257        for collaborator_user_id in &worktree.collaborator_user_ids {
258            if let Some(visible_worktrees) = self
259                .visible_worktrees_by_user_id
260                .get_mut(&collaborator_user_id)
261            {
262                visible_worktrees.remove(&worktree_id);
263            }
264        }
265
266        #[cfg(test)]
267        self.check_invariants();
268
269        Ok(worktree)
270    }
271
272    pub fn share_worktree(
273        &mut self,
274        worktree_id: u64,
275        connection_id: ConnectionId,
276        entries: HashMap<u64, proto::Entry>,
277    ) -> Option<Vec<UserId>> {
278        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
279            if worktree.host_connection_id == connection_id {
280                worktree.share = Some(WorktreeShare {
281                    guest_connection_ids: Default::default(),
282                    active_replica_ids: Default::default(),
283                    entries,
284                });
285                return Some(worktree.collaborator_user_ids.clone());
286            }
287        }
288        None
289    }
290
291    pub fn unshare_worktree(
292        &mut self,
293        worktree_id: u64,
294        acting_connection_id: ConnectionId,
295    ) -> tide::Result<UnsharedWorktree> {
296        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
297            worktree
298        } else {
299            return Err(anyhow!("no such worktree"))?;
300        };
301
302        if worktree.host_connection_id != acting_connection_id {
303            return Err(anyhow!("not your worktree"))?;
304        }
305
306        let connection_ids = worktree.connection_ids();
307        let collaborator_ids = worktree.collaborator_user_ids.clone();
308        if let Some(share) = worktree.share.take() {
309            for connection_id in share.guest_connection_ids.into_keys() {
310                if let Some(connection) = self.connections.get_mut(&connection_id) {
311                    connection.worktrees.remove(&worktree_id);
312                }
313            }
314
315            #[cfg(test)]
316            self.check_invariants();
317
318            Ok(UnsharedWorktree {
319                connection_ids,
320                collaborator_ids,
321            })
322        } else {
323            Err(anyhow!("worktree is not shared"))?
324        }
325    }
326
327    pub fn join_worktree(
328        &mut self,
329        connection_id: ConnectionId,
330        user_id: UserId,
331        worktree_id: u64,
332    ) -> tide::Result<JoinedWorktree> {
333        let connection = self
334            .connections
335            .get_mut(&connection_id)
336            .ok_or_else(|| anyhow!("no such connection"))?;
337        let worktree = self
338            .worktrees
339            .get_mut(&worktree_id)
340            .and_then(|worktree| {
341                if worktree.collaborator_user_ids.contains(&user_id) {
342                    Some(worktree)
343                } else {
344                    None
345                }
346            })
347            .ok_or_else(|| anyhow!("no such worktree"))?;
348
349        let share = worktree.share_mut()?;
350        connection.worktrees.insert(worktree_id);
351
352        let mut replica_id = 1;
353        while share.active_replica_ids.contains(&replica_id) {
354            replica_id += 1;
355        }
356        share.active_replica_ids.insert(replica_id);
357        share.guest_connection_ids.insert(connection_id, replica_id);
358
359        #[cfg(test)]
360        self.check_invariants();
361
362        Ok(JoinedWorktree {
363            replica_id,
364            worktree: &self.worktrees[&worktree_id],
365        })
366    }
367
368    pub fn leave_worktree(
369        &mut self,
370        connection_id: ConnectionId,
371        worktree_id: u64,
372    ) -> Option<LeftWorktree> {
373        let worktree = self.worktrees.get_mut(&worktree_id)?;
374        let share = worktree.share.as_mut()?;
375        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
376        share.active_replica_ids.remove(&replica_id);
377
378        if let Some(connection) = self.connections.get_mut(&connection_id) {
379            connection.worktrees.remove(&worktree_id);
380        }
381
382        let connection_ids = worktree.connection_ids();
383        let collaborator_ids = worktree.collaborator_user_ids.clone();
384
385        #[cfg(test)]
386        self.check_invariants();
387
388        Some(LeftWorktree {
389            connection_ids,
390            collaborator_ids,
391        })
392    }
393
394    pub fn update_worktree(
395        &mut self,
396        connection_id: ConnectionId,
397        worktree_id: u64,
398        removed_entries: &[u64],
399        updated_entries: &[proto::Entry],
400    ) -> tide::Result<Vec<ConnectionId>> {
401        let worktree = self.write_worktree(worktree_id, connection_id)?;
402        let share = worktree.share_mut()?;
403        for entry_id in removed_entries {
404            share.entries.remove(&entry_id);
405        }
406        for entry in updated_entries {
407            share.entries.insert(entry.id, entry.clone());
408        }
409        Ok(worktree.connection_ids())
410    }
411
412    pub fn worktree_host_connection_id(
413        &self,
414        connection_id: ConnectionId,
415        worktree_id: u64,
416    ) -> tide::Result<ConnectionId> {
417        Ok(self
418            .read_worktree(worktree_id, connection_id)?
419            .host_connection_id)
420    }
421
422    pub fn worktree_guest_connection_ids(
423        &self,
424        connection_id: ConnectionId,
425        worktree_id: u64,
426    ) -> tide::Result<Vec<ConnectionId>> {
427        Ok(self
428            .read_worktree(worktree_id, connection_id)?
429            .share()?
430            .guest_connection_ids
431            .keys()
432            .copied()
433            .collect())
434    }
435
436    pub fn worktree_connection_ids(
437        &self,
438        connection_id: ConnectionId,
439        worktree_id: u64,
440    ) -> tide::Result<Vec<ConnectionId>> {
441        Ok(self
442            .read_worktree(worktree_id, connection_id)?
443            .connection_ids())
444    }
445
446    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447        Some(self.channels.get(&channel_id)?.connection_ids())
448    }
449
450    fn read_worktree(
451        &self,
452        worktree_id: u64,
453        connection_id: ConnectionId,
454    ) -> tide::Result<&Worktree> {
455        let worktree = self
456            .worktrees
457            .get(&worktree_id)
458            .ok_or_else(|| anyhow!("worktree not found"))?;
459
460        if worktree.host_connection_id == connection_id
461            || worktree
462                .share()?
463                .guest_connection_ids
464                .contains_key(&connection_id)
465        {
466            Ok(worktree)
467        } else {
468            Err(anyhow!(
469                "{} is not a member of worktree {}",
470                connection_id,
471                worktree_id
472            ))?
473        }
474    }
475
476    fn write_worktree(
477        &mut self,
478        worktree_id: u64,
479        connection_id: ConnectionId,
480    ) -> tide::Result<&mut Worktree> {
481        let worktree = self
482            .worktrees
483            .get_mut(&worktree_id)
484            .ok_or_else(|| anyhow!("worktree not found"))?;
485
486        if worktree.host_connection_id == connection_id
487            || worktree.share.as_ref().map_or(false, |share| {
488                share.guest_connection_ids.contains_key(&connection_id)
489            })
490        {
491            Ok(worktree)
492        } else {
493            Err(anyhow!(
494                "{} is not a member of worktree {}",
495                connection_id,
496                worktree_id
497            ))?
498        }
499    }
500
501    #[cfg(test)]
502    fn check_invariants(&self) {
503        for (connection_id, connection) in &self.connections {
504            for worktree_id in &connection.worktrees {
505                let worktree = &self.worktrees.get(&worktree_id).unwrap();
506                if worktree.host_connection_id != *connection_id {
507                    assert!(worktree
508                        .share()
509                        .unwrap()
510                        .guest_connection_ids
511                        .contains_key(connection_id));
512                }
513            }
514            for channel_id in &connection.channels {
515                let channel = self.channels.get(channel_id).unwrap();
516                assert!(channel.connection_ids.contains(connection_id));
517            }
518            assert!(self
519                .connections_by_user_id
520                .get(&connection.user_id)
521                .unwrap()
522                .contains(connection_id));
523        }
524
525        for (user_id, connection_ids) in &self.connections_by_user_id {
526            for connection_id in connection_ids {
527                assert_eq!(
528                    self.connections.get(connection_id).unwrap().user_id,
529                    *user_id
530                );
531            }
532        }
533
534        for (worktree_id, worktree) in &self.worktrees {
535            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
536            assert!(host_connection.worktrees.contains(worktree_id));
537
538            for collaborator_id in &worktree.collaborator_user_ids {
539                let visible_worktree_ids = self
540                    .visible_worktrees_by_user_id
541                    .get(collaborator_id)
542                    .unwrap();
543                assert!(visible_worktree_ids.contains(worktree_id));
544            }
545
546            if let Some(share) = &worktree.share {
547                for guest_connection_id in share.guest_connection_ids.keys() {
548                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
549                    assert!(guest_connection.worktrees.contains(worktree_id));
550                }
551                assert_eq!(
552                    share.active_replica_ids.len(),
553                    share.guest_connection_ids.len(),
554                );
555                assert_eq!(
556                    share.active_replica_ids,
557                    share
558                        .guest_connection_ids
559                        .values()
560                        .copied()
561                        .collect::<HashSet<_>>(),
562                );
563            }
564        }
565
566        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
567            for worktree_id in visible_worktree_ids {
568                let worktree = self.worktrees.get(worktree_id).unwrap();
569                assert!(worktree.collaborator_user_ids.contains(user_id));
570            }
571        }
572
573        for (channel_id, channel) in &self.channels {
574            for connection_id in &channel.connection_ids {
575                let connection = self.connections.get(connection_id).unwrap();
576                assert!(connection.channels.contains(channel_id));
577            }
578        }
579    }
580}
581
582impl Worktree {
583    pub fn connection_ids(&self) -> Vec<ConnectionId> {
584        if let Some(share) = &self.share {
585            share
586                .guest_connection_ids
587                .keys()
588                .copied()
589                .chain(Some(self.host_connection_id))
590                .collect()
591        } else {
592            vec![self.host_connection_id]
593        }
594    }
595
596    pub fn share(&self) -> tide::Result<&WorktreeShare> {
597        Ok(self
598            .share
599            .as_ref()
600            .ok_or_else(|| anyhow!("worktree is not shared"))?)
601    }
602
603    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
604        Ok(self
605            .share
606            .as_mut()
607            .ok_or_else(|| anyhow!("worktree is not shared"))?)
608    }
609}
610
611impl Channel {
612    fn connection_ids(&self) -> Vec<ConnectionId> {
613        self.connection_ids.iter().copied().collect()
614    }
615}