store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use std::collections::{hash_map, HashMap, HashSet};
  4use zrpc::{proto, ConnectionId};
  5
  6#[derive(Default)]
  7pub struct Store {
  8    connections: HashMap<ConnectionId, ConnectionState>,
  9    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 10    worktrees: HashMap<u64, Worktree>,
 11    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 12    channels: HashMap<ChannelId, Channel>,
 13    next_worktree_id: u64,
 14}
 15
 16struct ConnectionState {
 17    user_id: UserId,
 18    worktrees: HashSet<u64>,
 19    channels: HashSet<ChannelId>,
 20}
 21
 22pub struct Worktree {
 23    pub host_connection_id: ConnectionId,
 24    pub collaborator_user_ids: Vec<UserId>,
 25    pub root_name: String,
 26    pub share: Option<WorktreeShare>,
 27}
 28
 29pub struct WorktreeShare {
 30    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 31    pub active_replica_ids: HashSet<ReplicaId>,
 32    pub entries: HashMap<u64, proto::Entry>,
 33}
 34
 35#[derive(Default)]
 36pub struct Channel {
 37    pub connection_ids: HashSet<ConnectionId>,
 38}
 39
/// Identifier of a replica participating in a shared worktree. Guests are
/// assigned the lowest free id starting from 1 when they join.
pub type ReplicaId = u16;
 41
 42#[derive(Default)]
 43pub struct RemovedConnectionState {
 44    pub hosted_worktrees: HashMap<u64, Worktree>,
 45    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 46    pub collaborator_ids: HashSet<UserId>,
 47}
 48
 49pub struct JoinedWorktree<'a> {
 50    pub replica_id: ReplicaId,
 51    pub worktree: &'a Worktree,
 52}
 53
 54pub struct UnsharedWorktree {
 55    pub connection_ids: Vec<ConnectionId>,
 56    pub collaborator_ids: Vec<UserId>,
 57}
 58
 59pub struct LeftWorktree {
 60    pub connection_ids: Vec<ConnectionId>,
 61    pub collaborator_ids: Vec<UserId>,
 62}
 63
 64impl Store {
 65    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 66        self.connections.insert(
 67            connection_id,
 68            ConnectionState {
 69                user_id,
 70                worktrees: Default::default(),
 71                channels: Default::default(),
 72            },
 73        );
 74        self.connections_by_user_id
 75            .entry(user_id)
 76            .or_default()
 77            .insert(connection_id);
 78    }
 79
 80    pub fn remove_connection(
 81        &mut self,
 82        connection_id: ConnectionId,
 83    ) -> tide::Result<RemovedConnectionState> {
 84        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 85            connection
 86        } else {
 87            return Err(anyhow!("no such connection"))?;
 88        };
 89
 90        for channel_id in &connection.channels {
 91            if let Some(channel) = self.channels.get_mut(&channel_id) {
 92                channel.connection_ids.remove(&connection_id);
 93            }
 94        }
 95
 96        let user_connections = self
 97            .connections_by_user_id
 98            .get_mut(&connection.user_id)
 99            .unwrap();
100        user_connections.remove(&connection_id);
101        if user_connections.is_empty() {
102            self.connections_by_user_id.remove(&connection.user_id);
103        }
104
105        let mut result = RemovedConnectionState::default();
106        for worktree_id in connection.worktrees.clone() {
107            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108                result
109                    .collaborator_ids
110                    .extend(worktree.collaborator_user_ids.iter().copied());
111                result.hosted_worktrees.insert(worktree_id, worktree);
112            } else {
113                if let Some(worktree) = self.worktrees.get(&worktree_id) {
114                    result
115                        .guest_worktree_ids
116                        .insert(worktree_id, worktree.connection_ids());
117                    result
118                        .collaborator_ids
119                        .extend(worktree.collaborator_user_ids.iter().copied());
120                }
121            }
122        }
123
124        Ok(result)
125    }
126
127    #[cfg(test)]
128    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129        self.channels.get(&id)
130    }
131
132    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133        if let Some(connection) = self.connections.get_mut(&connection_id) {
134            connection.channels.insert(channel_id);
135            self.channels
136                .entry(channel_id)
137                .or_default()
138                .connection_ids
139                .insert(connection_id);
140        }
141    }
142
143    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.remove(&channel_id);
146            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147                entry.get_mut().connection_ids.remove(&connection_id);
148                if entry.get_mut().connection_ids.is_empty() {
149                    entry.remove();
150                }
151            }
152        }
153    }
154
155    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156        Ok(self
157            .connections
158            .get(&connection_id)
159            .ok_or_else(|| anyhow!("unknown connection"))?
160            .user_id)
161    }
162
163    pub fn connection_ids_for_user<'a>(
164        &'a self,
165        user_id: UserId,
166    ) -> impl 'a + Iterator<Item = ConnectionId> {
167        self.connections_by_user_id
168            .get(&user_id)
169            .into_iter()
170            .flatten()
171            .copied()
172    }
173
174    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
175        let mut collaborators = HashMap::new();
176        for worktree_id in self
177            .visible_worktrees_by_user_id
178            .get(&user_id)
179            .unwrap_or(&HashSet::new())
180        {
181            let worktree = &self.worktrees[worktree_id];
182
183            let mut guests = HashSet::new();
184            if let Ok(share) = worktree.share() {
185                for guest_connection_id in share.guest_connection_ids.keys() {
186                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187                        guests.insert(user_id.to_proto());
188                    }
189                }
190            }
191
192            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193                collaborators
194                    .entry(host_user_id)
195                    .or_insert_with(|| proto::Collaborator {
196                        user_id: host_user_id.to_proto(),
197                        worktrees: Vec::new(),
198                    })
199                    .worktrees
200                    .push(proto::WorktreeMetadata {
201                        id: *worktree_id,
202                        root_name: worktree.root_name.clone(),
203                        is_shared: worktree.share.is_some(),
204                        guests: guests.into_iter().collect(),
205                    });
206            }
207        }
208
209        collaborators.into_values().collect()
210    }
211
212    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213        let worktree_id = self.next_worktree_id;
214        for collaborator_user_id in &worktree.collaborator_user_ids {
215            self.visible_worktrees_by_user_id
216                .entry(*collaborator_user_id)
217                .or_default()
218                .insert(worktree_id);
219        }
220        self.next_worktree_id += 1;
221        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222            connection.worktrees.insert(worktree_id);
223        }
224        self.worktrees.insert(worktree_id, worktree);
225
226        #[cfg(test)]
227        self.check_invariants();
228
229        worktree_id
230    }
231
232    pub fn remove_worktree(
233        &mut self,
234        worktree_id: u64,
235        acting_connection_id: ConnectionId,
236    ) -> tide::Result<Worktree> {
237        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238            if e.get().host_connection_id != acting_connection_id {
239                Err(anyhow!("not your worktree"))?;
240            }
241            e.remove()
242        } else {
243            return Err(anyhow!("no such worktree"))?;
244        };
245
246        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247            connection.worktrees.remove(&worktree_id);
248        }
249
250        if let Some(share) = &worktree.share {
251            for connection_id in share.guest_connection_ids.keys() {
252                if let Some(connection) = self.connections.get_mut(connection_id) {
253                    connection.worktrees.remove(&worktree_id);
254                }
255            }
256        }
257
258        for collaborator_user_id in &worktree.collaborator_user_ids {
259            if let Some(visible_worktrees) = self
260                .visible_worktrees_by_user_id
261                .get_mut(&collaborator_user_id)
262            {
263                visible_worktrees.remove(&worktree_id);
264            }
265        }
266
267        #[cfg(test)]
268        self.check_invariants();
269
270        Ok(worktree)
271    }
272
273    pub fn share_worktree(
274        &mut self,
275        worktree_id: u64,
276        connection_id: ConnectionId,
277        entries: HashMap<u64, proto::Entry>,
278    ) -> Option<Vec<UserId>> {
279        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280            if worktree.host_connection_id == connection_id {
281                worktree.share = Some(WorktreeShare {
282                    guest_connection_ids: Default::default(),
283                    active_replica_ids: Default::default(),
284                    entries,
285                });
286                return Some(worktree.collaborator_user_ids.clone());
287            }
288        }
289        None
290    }
291
292    pub fn unshare_worktree(
293        &mut self,
294        worktree_id: u64,
295        acting_connection_id: ConnectionId,
296    ) -> tide::Result<UnsharedWorktree> {
297        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298            worktree
299        } else {
300            return Err(anyhow!("no such worktree"))?;
301        };
302
303        if worktree.host_connection_id != acting_connection_id {
304            return Err(anyhow!("not your worktree"))?;
305        }
306
307        let connection_ids = worktree.connection_ids();
308
309        if let Some(_) = worktree.share.take() {
310            for connection_id in &connection_ids {
311                if let Some(connection) = self.connections.get_mut(connection_id) {
312                    connection.worktrees.remove(&worktree_id);
313                }
314            }
315            Ok(UnsharedWorktree {
316                connection_ids,
317                collaborator_ids: worktree.collaborator_user_ids.clone(),
318            })
319        } else {
320            Err(anyhow!("worktree is not shared"))?
321        }
322    }
323
324    pub fn join_worktree(
325        &mut self,
326        connection_id: ConnectionId,
327        user_id: UserId,
328        worktree_id: u64,
329    ) -> tide::Result<JoinedWorktree> {
330        let connection = self
331            .connections
332            .get_mut(&connection_id)
333            .ok_or_else(|| anyhow!("no such connection"))?;
334        let worktree = self
335            .worktrees
336            .get_mut(&worktree_id)
337            .and_then(|worktree| {
338                if worktree.collaborator_user_ids.contains(&user_id) {
339                    Some(worktree)
340                } else {
341                    None
342                }
343            })
344            .ok_or_else(|| anyhow!("no such worktree"))?;
345
346        let share = worktree.share_mut()?;
347        connection.worktrees.insert(worktree_id);
348
349        let mut replica_id = 1;
350        while share.active_replica_ids.contains(&replica_id) {
351            replica_id += 1;
352        }
353        share.active_replica_ids.insert(replica_id);
354        share.guest_connection_ids.insert(connection_id, replica_id);
355        Ok(JoinedWorktree {
356            replica_id,
357            worktree,
358        })
359    }
360
361    pub fn leave_worktree(
362        &mut self,
363        connection_id: ConnectionId,
364        worktree_id: u64,
365    ) -> Option<LeftWorktree> {
366        let worktree = self.worktrees.get_mut(&worktree_id)?;
367        let share = worktree.share.as_mut()?;
368        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
369        share.active_replica_ids.remove(&replica_id);
370        Some(LeftWorktree {
371            connection_ids: worktree.connection_ids(),
372            collaborator_ids: worktree.collaborator_user_ids.clone(),
373        })
374    }
375
376    pub fn update_worktree(
377        &mut self,
378        connection_id: ConnectionId,
379        worktree_id: u64,
380        removed_entries: &[u64],
381        updated_entries: &[proto::Entry],
382    ) -> tide::Result<Vec<ConnectionId>> {
383        let worktree = self.write_worktree(worktree_id, connection_id)?;
384        let share = worktree.share_mut()?;
385        for entry_id in removed_entries {
386            share.entries.remove(&entry_id);
387        }
388        for entry in updated_entries {
389            share.entries.insert(entry.id, entry.clone());
390        }
391        Ok(worktree.connection_ids())
392    }
393
394    pub fn worktree_host_connection_id(
395        &self,
396        connection_id: ConnectionId,
397        worktree_id: u64,
398    ) -> tide::Result<ConnectionId> {
399        Ok(self
400            .read_worktree(worktree_id, connection_id)?
401            .host_connection_id)
402    }
403
404    pub fn worktree_guest_connection_ids(
405        &self,
406        connection_id: ConnectionId,
407        worktree_id: u64,
408    ) -> tide::Result<Vec<ConnectionId>> {
409        Ok(self
410            .read_worktree(worktree_id, connection_id)?
411            .share()?
412            .guest_connection_ids
413            .keys()
414            .copied()
415            .collect())
416    }
417
418    pub fn worktree_connection_ids(
419        &self,
420        connection_id: ConnectionId,
421        worktree_id: u64,
422    ) -> tide::Result<Vec<ConnectionId>> {
423        Ok(self
424            .read_worktree(worktree_id, connection_id)?
425            .connection_ids())
426    }
427
428    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
429        Some(self.channels.get(&channel_id)?.connection_ids())
430    }
431
432    fn read_worktree(
433        &self,
434        worktree_id: u64,
435        connection_id: ConnectionId,
436    ) -> tide::Result<&Worktree> {
437        let worktree = self
438            .worktrees
439            .get(&worktree_id)
440            .ok_or_else(|| anyhow!("worktree not found"))?;
441
442        if worktree.host_connection_id == connection_id
443            || worktree
444                .share()?
445                .guest_connection_ids
446                .contains_key(&connection_id)
447        {
448            Ok(worktree)
449        } else {
450            Err(anyhow!(
451                "{} is not a member of worktree {}",
452                connection_id,
453                worktree_id
454            ))?
455        }
456    }
457
458    fn write_worktree(
459        &mut self,
460        worktree_id: u64,
461        connection_id: ConnectionId,
462    ) -> tide::Result<&mut Worktree> {
463        let worktree = self
464            .worktrees
465            .get_mut(&worktree_id)
466            .ok_or_else(|| anyhow!("worktree not found"))?;
467
468        if worktree.host_connection_id == connection_id
469            || worktree.share.as_ref().map_or(false, |share| {
470                share.guest_connection_ids.contains_key(&connection_id)
471            })
472        {
473            Ok(worktree)
474        } else {
475            Err(anyhow!(
476                "{} is not a member of worktree {}",
477                connection_id,
478                worktree_id
479            ))?
480        }
481    }
482
483    #[cfg(test)]
484    fn check_invariants(&self) {
485        for (connection_id, connection) in &self.connections {
486            for worktree_id in &connection.worktrees {
487                let worktree = &self.worktrees.get(&worktree_id).unwrap();
488                if worktree.host_connection_id != *connection_id {
489                    assert!(worktree
490                        .share()
491                        .unwrap()
492                        .guest_connection_ids
493                        .contains_key(connection_id));
494                }
495            }
496            for channel_id in &connection.channels {
497                let channel = self.channels.get(channel_id).unwrap();
498                assert!(channel.connection_ids.contains(connection_id));
499            }
500            assert!(self
501                .connections_by_user_id
502                .get(&connection.user_id)
503                .unwrap()
504                .contains(connection_id));
505        }
506
507        for (user_id, connection_ids) in &self.connections_by_user_id {
508            for connection_id in connection_ids {
509                assert_eq!(
510                    self.connections.get(connection_id).unwrap().user_id,
511                    *user_id
512                );
513            }
514        }
515
516        for (worktree_id, worktree) in &self.worktrees {
517            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
518            assert!(host_connection.worktrees.contains(worktree_id));
519
520            for collaborator_id in &worktree.collaborator_user_ids {
521                let visible_worktree_ids = self
522                    .visible_worktrees_by_user_id
523                    .get(collaborator_id)
524                    .unwrap();
525                assert!(visible_worktree_ids.contains(worktree_id));
526            }
527
528            if let Some(share) = &worktree.share {
529                for guest_connection_id in share.guest_connection_ids.keys() {
530                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
531                    assert!(guest_connection.worktrees.contains(worktree_id));
532                }
533                assert_eq!(
534                    share.active_replica_ids.len(),
535                    share.guest_connection_ids.len(),
536                );
537                assert_eq!(
538                    share.active_replica_ids,
539                    share
540                        .guest_connection_ids
541                        .values()
542                        .copied()
543                        .collect::<HashSet<_>>(),
544                );
545            }
546        }
547
548        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
549            for worktree_id in visible_worktree_ids {
550                let worktree = self.worktrees.get(worktree_id).unwrap();
551                assert!(worktree.collaborator_user_ids.contains(user_id));
552            }
553        }
554
555        for (channel_id, channel) in &self.channels {
556            for connection_id in &channel.connection_ids {
557                let connection = self.connections.get(connection_id).unwrap();
558                assert!(connection.channels.contains(channel_id));
559            }
560        }
561    }
562}
563
564impl Worktree {
565    pub fn connection_ids(&self) -> Vec<ConnectionId> {
566        if let Some(share) = &self.share {
567            share
568                .guest_connection_ids
569                .keys()
570                .copied()
571                .chain(Some(self.host_connection_id))
572                .collect()
573        } else {
574            vec![self.host_connection_id]
575        }
576    }
577
578    pub fn share(&self) -> tide::Result<&WorktreeShare> {
579        Ok(self
580            .share
581            .as_ref()
582            .ok_or_else(|| anyhow!("worktree is not shared"))?)
583    }
584
585    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
586        Ok(self
587            .share
588            .as_mut()
589            .ok_or_else(|| anyhow!("worktree is not shared"))?)
590    }
591}
592
593impl Channel {
594    fn connection_ids(&self) -> Vec<ConnectionId> {
595        self.connection_ids.iter().copied().collect()
596    }
597}