// store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use rpc::{proto, ConnectionId};
  4use std::collections::{hash_map, HashMap, HashSet};
  5
  6#[derive(Default)]
  7pub struct Store {
  8    connections: HashMap<ConnectionId, ConnectionState>,
  9    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 10    worktrees: HashMap<u64, Worktree>,
 11    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 12    channels: HashMap<ChannelId, Channel>,
 13    next_worktree_id: u64,
 14}
 15
 16struct ConnectionState {
 17    user_id: UserId,
 18    worktrees: HashSet<u64>,
 19    channels: HashSet<ChannelId>,
 20}
 21
 22pub struct Worktree {
 23    pub host_connection_id: ConnectionId,
 24    pub host_user_id: UserId,
 25    pub contact_user_ids: Vec<UserId>,
 26    pub root_name: String,
 27    pub share: Option<WorktreeShare>,
 28}
 29
 30pub struct WorktreeShare {
 31    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 32    pub active_replica_ids: HashSet<ReplicaId>,
 33    pub entries: HashMap<u64, proto::Entry>,
 34}
 35
 36#[derive(Default)]
 37pub struct Channel {
 38    pub connection_ids: HashSet<ConnectionId>,
 39}
 40
 /// Identifier assigned to each guest of a shared worktree (`Store::join_worktree` hands them
/// out starting at 1).
pub type ReplicaId = u16;
 42
 43#[derive(Default)]
 44pub struct RemovedConnectionState {
 45    pub hosted_worktrees: HashMap<u64, Worktree>,
 46    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 47    pub contact_ids: HashSet<UserId>,
 48}
 49
 50pub struct JoinedWorktree<'a> {
 51    pub replica_id: ReplicaId,
 52    pub worktree: &'a Worktree,
 53}
 54
 55pub struct UnsharedWorktree {
 56    pub connection_ids: Vec<ConnectionId>,
 57    pub contact_ids: Vec<UserId>,
 58}
 59
 60pub struct LeftWorktree {
 61    pub connection_ids: Vec<ConnectionId>,
 62    pub contact_ids: Vec<UserId>,
 63}
 64
 65impl Store {
 66    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 67        self.connections.insert(
 68            connection_id,
 69            ConnectionState {
 70                user_id,
 71                worktrees: Default::default(),
 72                channels: Default::default(),
 73            },
 74        );
 75        self.connections_by_user_id
 76            .entry(user_id)
 77            .or_default()
 78            .insert(connection_id);
 79    }
 80
 81    pub fn remove_connection(
 82        &mut self,
 83        connection_id: ConnectionId,
 84    ) -> tide::Result<RemovedConnectionState> {
 85        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 86            connection
 87        } else {
 88            return Err(anyhow!("no such connection"))?;
 89        };
 90
 91        for channel_id in &connection.channels {
 92            if let Some(channel) = self.channels.get_mut(&channel_id) {
 93                channel.connection_ids.remove(&connection_id);
 94            }
 95        }
 96
 97        let user_connections = self
 98            .connections_by_user_id
 99            .get_mut(&connection.user_id)
100            .unwrap();
101        user_connections.remove(&connection_id);
102        if user_connections.is_empty() {
103            self.connections_by_user_id.remove(&connection.user_id);
104        }
105
106        let mut result = RemovedConnectionState::default();
107        for worktree_id in connection.worktrees.clone() {
108            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109                result
110                    .contact_ids
111                    .extend(worktree.contact_user_ids.iter().copied());
112                result.hosted_worktrees.insert(worktree_id, worktree);
113            } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
114                result
115                    .guest_worktree_ids
116                    .insert(worktree_id, worktree.connection_ids);
117                result.contact_ids.extend(worktree.contact_ids);
118            }
119        }
120
121        #[cfg(test)]
122        self.check_invariants();
123
124        Ok(result)
125    }
126
127    #[cfg(test)]
128    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129        self.channels.get(&id)
130    }
131
132    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133        if let Some(connection) = self.connections.get_mut(&connection_id) {
134            connection.channels.insert(channel_id);
135            self.channels
136                .entry(channel_id)
137                .or_default()
138                .connection_ids
139                .insert(connection_id);
140        }
141    }
142
143    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.remove(&channel_id);
146            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147                entry.get_mut().connection_ids.remove(&connection_id);
148                if entry.get_mut().connection_ids.is_empty() {
149                    entry.remove();
150                }
151            }
152        }
153    }
154
155    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156        Ok(self
157            .connections
158            .get(&connection_id)
159            .ok_or_else(|| anyhow!("unknown connection"))?
160            .user_id)
161    }
162
163    pub fn connection_ids_for_user<'a>(
164        &'a self,
165        user_id: UserId,
166    ) -> impl 'a + Iterator<Item = ConnectionId> {
167        self.connections_by_user_id
168            .get(&user_id)
169            .into_iter()
170            .flatten()
171            .copied()
172    }
173
174    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
175        let mut contacts = HashMap::new();
176        for worktree_id in self
177            .visible_worktrees_by_user_id
178            .get(&user_id)
179            .unwrap_or(&HashSet::new())
180        {
181            let worktree = &self.worktrees[worktree_id];
182
183            let mut guests = HashSet::new();
184            if let Ok(share) = worktree.share() {
185                for guest_connection_id in share.guests.keys() {
186                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187                        guests.insert(user_id.to_proto());
188                    }
189                }
190            }
191
192            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193                contacts
194                    .entry(host_user_id)
195                    .or_insert_with(|| proto::Contact {
196                        user_id: host_user_id.to_proto(),
197                        worktrees: Vec::new(),
198                    })
199                    .worktrees
200                    .push(proto::WorktreeMetadata {
201                        id: *worktree_id,
202                        root_name: worktree.root_name.clone(),
203                        is_shared: worktree.share.is_some(),
204                        guests: guests.into_iter().collect(),
205                    });
206            }
207        }
208
209        contacts.into_values().collect()
210    }
211
212    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213        let worktree_id = self.next_worktree_id;
214        for contact_user_id in &worktree.contact_user_ids {
215            self.visible_worktrees_by_user_id
216                .entry(*contact_user_id)
217                .or_default()
218                .insert(worktree_id);
219        }
220        self.next_worktree_id += 1;
221        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222            connection.worktrees.insert(worktree_id);
223        }
224        self.worktrees.insert(worktree_id, worktree);
225
226        #[cfg(test)]
227        self.check_invariants();
228
229        worktree_id
230    }
231
232    pub fn remove_worktree(
233        &mut self,
234        worktree_id: u64,
235        acting_connection_id: ConnectionId,
236    ) -> tide::Result<Worktree> {
237        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238            if e.get().host_connection_id != acting_connection_id {
239                Err(anyhow!("not your worktree"))?;
240            }
241            e.remove()
242        } else {
243            return Err(anyhow!("no such worktree"))?;
244        };
245
246        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247            connection.worktrees.remove(&worktree_id);
248        }
249
250        if let Some(share) = &worktree.share {
251            for connection_id in share.guests.keys() {
252                if let Some(connection) = self.connections.get_mut(connection_id) {
253                    connection.worktrees.remove(&worktree_id);
254                }
255            }
256        }
257
258        for contact_user_id in &worktree.contact_user_ids {
259            if let Some(visible_worktrees) =
260                self.visible_worktrees_by_user_id.get_mut(&contact_user_id)
261            {
262                visible_worktrees.remove(&worktree_id);
263            }
264        }
265
266        #[cfg(test)]
267        self.check_invariants();
268
269        Ok(worktree)
270    }
271
272    pub fn share_worktree(
273        &mut self,
274        worktree_id: u64,
275        connection_id: ConnectionId,
276        entries: HashMap<u64, proto::Entry>,
277    ) -> Option<Vec<UserId>> {
278        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
279            if worktree.host_connection_id == connection_id {
280                worktree.share = Some(WorktreeShare {
281                    guests: Default::default(),
282                    active_replica_ids: Default::default(),
283                    entries,
284                });
285                return Some(worktree.contact_user_ids.clone());
286            }
287        }
288        None
289    }
290
291    pub fn unshare_worktree(
292        &mut self,
293        worktree_id: u64,
294        acting_connection_id: ConnectionId,
295    ) -> tide::Result<UnsharedWorktree> {
296        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
297            worktree
298        } else {
299            return Err(anyhow!("no such worktree"))?;
300        };
301
302        if worktree.host_connection_id != acting_connection_id {
303            return Err(anyhow!("not your worktree"))?;
304        }
305
306        let connection_ids = worktree.connection_ids();
307        let contact_ids = worktree.contact_user_ids.clone();
308        if let Some(share) = worktree.share.take() {
309            for connection_id in share.guests.into_keys() {
310                if let Some(connection) = self.connections.get_mut(&connection_id) {
311                    connection.worktrees.remove(&worktree_id);
312                }
313            }
314
315            #[cfg(test)]
316            self.check_invariants();
317
318            Ok(UnsharedWorktree {
319                connection_ids,
320                contact_ids,
321            })
322        } else {
323            Err(anyhow!("worktree is not shared"))?
324        }
325    }
326
327    pub fn join_worktree(
328        &mut self,
329        connection_id: ConnectionId,
330        user_id: UserId,
331        worktree_id: u64,
332    ) -> tide::Result<JoinedWorktree> {
333        let connection = self
334            .connections
335            .get_mut(&connection_id)
336            .ok_or_else(|| anyhow!("no such connection"))?;
337        let worktree = self
338            .worktrees
339            .get_mut(&worktree_id)
340            .and_then(|worktree| {
341                if worktree.contact_user_ids.contains(&user_id) {
342                    Some(worktree)
343                } else {
344                    None
345                }
346            })
347            .ok_or_else(|| anyhow!("no such worktree"))?;
348
349        let share = worktree.share_mut()?;
350        connection.worktrees.insert(worktree_id);
351
352        let mut replica_id = 1;
353        while share.active_replica_ids.contains(&replica_id) {
354            replica_id += 1;
355        }
356        share.active_replica_ids.insert(replica_id);
357        share.guests.insert(connection_id, (replica_id, user_id));
358
359        #[cfg(test)]
360        self.check_invariants();
361
362        Ok(JoinedWorktree {
363            replica_id,
364            worktree: &self.worktrees[&worktree_id],
365        })
366    }
367
368    pub fn leave_worktree(
369        &mut self,
370        connection_id: ConnectionId,
371        worktree_id: u64,
372    ) -> Option<LeftWorktree> {
373        let worktree = self.worktrees.get_mut(&worktree_id)?;
374        let share = worktree.share.as_mut()?;
375        let (replica_id, _) = share.guests.remove(&connection_id)?;
376        share.active_replica_ids.remove(&replica_id);
377
378        if let Some(connection) = self.connections.get_mut(&connection_id) {
379            connection.worktrees.remove(&worktree_id);
380        }
381
382        let connection_ids = worktree.connection_ids();
383        let contact_ids = worktree.contact_user_ids.clone();
384
385        #[cfg(test)]
386        self.check_invariants();
387
388        Some(LeftWorktree {
389            connection_ids,
390            contact_ids,
391        })
392    }
393
394    pub fn update_worktree(
395        &mut self,
396        connection_id: ConnectionId,
397        worktree_id: u64,
398        removed_entries: &[u64],
399        updated_entries: &[proto::Entry],
400    ) -> tide::Result<Vec<ConnectionId>> {
401        let worktree = self.write_worktree(worktree_id, connection_id)?;
402        let share = worktree.share_mut()?;
403        for entry_id in removed_entries {
404            share.entries.remove(&entry_id);
405        }
406        for entry in updated_entries {
407            share.entries.insert(entry.id, entry.clone());
408        }
409        Ok(worktree.connection_ids())
410    }
411
412    pub fn worktree_host_connection_id(
413        &self,
414        connection_id: ConnectionId,
415        worktree_id: u64,
416    ) -> tide::Result<ConnectionId> {
417        Ok(self
418            .read_worktree(worktree_id, connection_id)?
419            .host_connection_id)
420    }
421
422    pub fn worktree_guest_connection_ids(
423        &self,
424        connection_id: ConnectionId,
425        worktree_id: u64,
426    ) -> tide::Result<Vec<ConnectionId>> {
427        Ok(self
428            .read_worktree(worktree_id, connection_id)?
429            .share()?
430            .guests
431            .keys()
432            .copied()
433            .collect())
434    }
435
436    pub fn worktree_connection_ids(
437        &self,
438        connection_id: ConnectionId,
439        worktree_id: u64,
440    ) -> tide::Result<Vec<ConnectionId>> {
441        Ok(self
442            .read_worktree(worktree_id, connection_id)?
443            .connection_ids())
444    }
445
446    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447        Some(self.channels.get(&channel_id)?.connection_ids())
448    }
449
450    fn read_worktree(
451        &self,
452        worktree_id: u64,
453        connection_id: ConnectionId,
454    ) -> tide::Result<&Worktree> {
455        let worktree = self
456            .worktrees
457            .get(&worktree_id)
458            .ok_or_else(|| anyhow!("worktree not found"))?;
459
460        if worktree.host_connection_id == connection_id
461            || worktree.share()?.guests.contains_key(&connection_id)
462        {
463            Ok(worktree)
464        } else {
465            Err(anyhow!(
466                "{} is not a member of worktree {}",
467                connection_id,
468                worktree_id
469            ))?
470        }
471    }
472
473    fn write_worktree(
474        &mut self,
475        worktree_id: u64,
476        connection_id: ConnectionId,
477    ) -> tide::Result<&mut Worktree> {
478        let worktree = self
479            .worktrees
480            .get_mut(&worktree_id)
481            .ok_or_else(|| anyhow!("worktree not found"))?;
482
483        if worktree.host_connection_id == connection_id
484            || worktree
485                .share
486                .as_ref()
487                .map_or(false, |share| share.guests.contains_key(&connection_id))
488        {
489            Ok(worktree)
490        } else {
491            Err(anyhow!(
492                "{} is not a member of worktree {}",
493                connection_id,
494                worktree_id
495            ))?
496        }
497    }
498
499    #[cfg(test)]
500    fn check_invariants(&self) {
501        for (connection_id, connection) in &self.connections {
502            for worktree_id in &connection.worktrees {
503                let worktree = &self.worktrees.get(&worktree_id).unwrap();
504                if worktree.host_connection_id != *connection_id {
505                    assert!(worktree.share().unwrap().guests.contains_key(connection_id));
506                }
507            }
508            for channel_id in &connection.channels {
509                let channel = self.channels.get(channel_id).unwrap();
510                assert!(channel.connection_ids.contains(connection_id));
511            }
512            assert!(self
513                .connections_by_user_id
514                .get(&connection.user_id)
515                .unwrap()
516                .contains(connection_id));
517        }
518
519        for (user_id, connection_ids) in &self.connections_by_user_id {
520            for connection_id in connection_ids {
521                assert_eq!(
522                    self.connections.get(connection_id).unwrap().user_id,
523                    *user_id
524                );
525            }
526        }
527
528        for (worktree_id, worktree) in &self.worktrees {
529            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
530            assert!(host_connection.worktrees.contains(worktree_id));
531
532            for contact_id in &worktree.contact_user_ids {
533                let visible_worktree_ids =
534                    self.visible_worktrees_by_user_id.get(contact_id).unwrap();
535                assert!(visible_worktree_ids.contains(worktree_id));
536            }
537
538            if let Some(share) = &worktree.share {
539                for guest_connection_id in share.guests.keys() {
540                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
541                    assert!(guest_connection.worktrees.contains(worktree_id));
542                }
543                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
544                assert_eq!(
545                    share.active_replica_ids,
546                    share
547                        .guests
548                        .values()
549                        .map(|(replica_id, _)| *replica_id)
550                        .collect::<HashSet<_>>(),
551                );
552            }
553        }
554
555        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
556            for worktree_id in visible_worktree_ids {
557                let worktree = self.worktrees.get(worktree_id).unwrap();
558                assert!(worktree.contact_user_ids.contains(user_id));
559            }
560        }
561
562        for (channel_id, channel) in &self.channels {
563            for connection_id in &channel.connection_ids {
564                let connection = self.connections.get(connection_id).unwrap();
565                assert!(connection.channels.contains(channel_id));
566            }
567        }
568    }
569}
570
571impl Worktree {
572    pub fn connection_ids(&self) -> Vec<ConnectionId> {
573        if let Some(share) = &self.share {
574            share
575                .guests
576                .keys()
577                .copied()
578                .chain(Some(self.host_connection_id))
579                .collect()
580        } else {
581            vec![self.host_connection_id]
582        }
583    }
584
585    pub fn share(&self) -> tide::Result<&WorktreeShare> {
586        Ok(self
587            .share
588            .as_ref()
589            .ok_or_else(|| anyhow!("worktree is not shared"))?)
590    }
591
592    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
593        Ok(self
594            .share
595            .as_mut()
596            .ok_or_else(|| anyhow!("worktree is not shared"))?)
597    }
598}
599
600impl Channel {
601    fn connection_ids(&self) -> Vec<ConnectionId> {
602        self.connection_ids.iter().copied().collect()
603    }
604}