store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use rpc::{proto, ConnectionId};
  4use std::collections::{hash_map, HashMap, HashSet};
  5
  6#[derive(Default)]
  7pub struct Store {
  8    connections: HashMap<ConnectionId, ConnectionState>,
  9    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 10    worktrees: HashMap<u64, Worktree>,
 11    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 12    channels: HashMap<ChannelId, Channel>,
 13    next_worktree_id: u64,
 14}
 15
 16struct ConnectionState {
 17    user_id: UserId,
 18    worktrees: HashSet<u64>,
 19    channels: HashSet<ChannelId>,
 20}
 21
 22pub struct Worktree {
 23    pub host_connection_id: ConnectionId,
 24    pub host_user_id: UserId,
 25    pub authorized_user_ids: Vec<UserId>,
 26    pub root_name: String,
 27    pub share: Option<WorktreeShare>,
 28}
 29
 30pub struct WorktreeShare {
 31    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 32    pub active_replica_ids: HashSet<ReplicaId>,
 33    pub entries: HashMap<u64, proto::Entry>,
 34}
 35
 36#[derive(Default)]
 37pub struct Channel {
 38    pub connection_ids: HashSet<ConnectionId>,
 39}
 40
 /// Identifier of a replica of a shared worktree. Guests are handed ids
/// starting from 1 by `Store::join_worktree`.
pub type ReplicaId = u16;
 42
 43#[derive(Default)]
 44pub struct RemovedConnectionState {
 45    pub hosted_worktrees: HashMap<u64, Worktree>,
 46    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 47    pub contact_ids: HashSet<UserId>,
 48}
 49
 50pub struct JoinedWorktree<'a> {
 51    pub replica_id: ReplicaId,
 52    pub worktree: &'a Worktree,
 53}
 54
 55pub struct UnsharedWorktree {
 56    pub connection_ids: Vec<ConnectionId>,
 57    pub authorized_user_ids: Vec<UserId>,
 58}
 59
 60pub struct LeftWorktree {
 61    pub connection_ids: Vec<ConnectionId>,
 62    pub authorized_user_ids: Vec<UserId>,
 63}
 64
 65impl Store {
 66    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 67        self.connections.insert(
 68            connection_id,
 69            ConnectionState {
 70                user_id,
 71                worktrees: Default::default(),
 72                channels: Default::default(),
 73            },
 74        );
 75        self.connections_by_user_id
 76            .entry(user_id)
 77            .or_default()
 78            .insert(connection_id);
 79    }
 80
 81    pub fn remove_connection(
 82        &mut self,
 83        connection_id: ConnectionId,
 84    ) -> tide::Result<RemovedConnectionState> {
 85        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 86            connection
 87        } else {
 88            return Err(anyhow!("no such connection"))?;
 89        };
 90
 91        for channel_id in &connection.channels {
 92            if let Some(channel) = self.channels.get_mut(&channel_id) {
 93                channel.connection_ids.remove(&connection_id);
 94            }
 95        }
 96
 97        let user_connections = self
 98            .connections_by_user_id
 99            .get_mut(&connection.user_id)
100            .unwrap();
101        user_connections.remove(&connection_id);
102        if user_connections.is_empty() {
103            self.connections_by_user_id.remove(&connection.user_id);
104        }
105
106        let mut result = RemovedConnectionState::default();
107        for worktree_id in connection.worktrees.clone() {
108            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109                result
110                    .contact_ids
111                    .extend(worktree.authorized_user_ids.iter().copied());
112                result.hosted_worktrees.insert(worktree_id, worktree);
113            } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
114                result
115                    .guest_worktree_ids
116                    .insert(worktree_id, worktree.connection_ids);
117                result.contact_ids.extend(worktree.authorized_user_ids);
118            }
119        }
120
121        #[cfg(test)]
122        self.check_invariants();
123
124        Ok(result)
125    }
126
127    #[cfg(test)]
128    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129        self.channels.get(&id)
130    }
131
132    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133        if let Some(connection) = self.connections.get_mut(&connection_id) {
134            connection.channels.insert(channel_id);
135            self.channels
136                .entry(channel_id)
137                .or_default()
138                .connection_ids
139                .insert(connection_id);
140        }
141    }
142
143    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144        if let Some(connection) = self.connections.get_mut(&connection_id) {
145            connection.channels.remove(&channel_id);
146            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147                entry.get_mut().connection_ids.remove(&connection_id);
148                if entry.get_mut().connection_ids.is_empty() {
149                    entry.remove();
150                }
151            }
152        }
153    }
154
155    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156        Ok(self
157            .connections
158            .get(&connection_id)
159            .ok_or_else(|| anyhow!("unknown connection"))?
160            .user_id)
161    }
162
163    pub fn connection_ids_for_user<'a>(
164        &'a self,
165        user_id: UserId,
166    ) -> impl 'a + Iterator<Item = ConnectionId> {
167        self.connections_by_user_id
168            .get(&user_id)
169            .into_iter()
170            .flatten()
171            .copied()
172    }
173
174    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
175        let mut contacts = HashMap::new();
176        for worktree_id in self
177            .visible_worktrees_by_user_id
178            .get(&user_id)
179            .unwrap_or(&HashSet::new())
180        {
181            let worktree = &self.worktrees[worktree_id];
182
183            let mut guests = HashSet::new();
184            if let Ok(share) = worktree.share() {
185                for guest_connection_id in share.guests.keys() {
186                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187                        guests.insert(user_id.to_proto());
188                    }
189                }
190            }
191
192            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193                contacts
194                    .entry(host_user_id)
195                    .or_insert_with(|| proto::Contact {
196                        user_id: host_user_id.to_proto(),
197                        worktrees: Vec::new(),
198                    })
199                    .worktrees
200                    .push(proto::WorktreeMetadata {
201                        id: *worktree_id,
202                        root_name: worktree.root_name.clone(),
203                        is_shared: worktree.share.is_some(),
204                        guests: guests.into_iter().collect(),
205                    });
206            }
207        }
208
209        contacts.into_values().collect()
210    }
211
212    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213        let worktree_id = self.next_worktree_id;
214        for authorized_user_id in &worktree.authorized_user_ids {
215            self.visible_worktrees_by_user_id
216                .entry(*authorized_user_id)
217                .or_default()
218                .insert(worktree_id);
219        }
220        self.next_worktree_id += 1;
221        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222            connection.worktrees.insert(worktree_id);
223        }
224        self.worktrees.insert(worktree_id, worktree);
225
226        #[cfg(test)]
227        self.check_invariants();
228
229        worktree_id
230    }
231
232    pub fn remove_worktree(
233        &mut self,
234        worktree_id: u64,
235        acting_connection_id: ConnectionId,
236    ) -> tide::Result<Worktree> {
237        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238            if e.get().host_connection_id != acting_connection_id {
239                Err(anyhow!("not your worktree"))?;
240            }
241            e.remove()
242        } else {
243            return Err(anyhow!("no such worktree"))?;
244        };
245
246        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247            connection.worktrees.remove(&worktree_id);
248        }
249
250        if let Some(share) = &worktree.share {
251            for connection_id in share.guests.keys() {
252                if let Some(connection) = self.connections.get_mut(connection_id) {
253                    connection.worktrees.remove(&worktree_id);
254                }
255            }
256        }
257
258        for authorized_user_id in &worktree.authorized_user_ids {
259            if let Some(visible_worktrees) = self
260                .visible_worktrees_by_user_id
261                .get_mut(&authorized_user_id)
262            {
263                visible_worktrees.remove(&worktree_id);
264            }
265        }
266
267        #[cfg(test)]
268        self.check_invariants();
269
270        Ok(worktree)
271    }
272
273    pub fn share_worktree(
274        &mut self,
275        worktree_id: u64,
276        connection_id: ConnectionId,
277        entries: HashMap<u64, proto::Entry>,
278    ) -> Option<Vec<UserId>> {
279        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280            if worktree.host_connection_id == connection_id {
281                worktree.share = Some(WorktreeShare {
282                    guests: Default::default(),
283                    active_replica_ids: Default::default(),
284                    entries,
285                });
286                return Some(worktree.authorized_user_ids.clone());
287            }
288        }
289        None
290    }
291
292    pub fn unshare_worktree(
293        &mut self,
294        worktree_id: u64,
295        acting_connection_id: ConnectionId,
296    ) -> tide::Result<UnsharedWorktree> {
297        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298            worktree
299        } else {
300            return Err(anyhow!("no such worktree"))?;
301        };
302
303        if worktree.host_connection_id != acting_connection_id {
304            return Err(anyhow!("not your worktree"))?;
305        }
306
307        let connection_ids = worktree.connection_ids();
308        let authorized_user_ids = worktree.authorized_user_ids.clone();
309        if let Some(share) = worktree.share.take() {
310            for connection_id in share.guests.into_keys() {
311                if let Some(connection) = self.connections.get_mut(&connection_id) {
312                    connection.worktrees.remove(&worktree_id);
313                }
314            }
315
316            #[cfg(test)]
317            self.check_invariants();
318
319            Ok(UnsharedWorktree {
320                connection_ids,
321                authorized_user_ids,
322            })
323        } else {
324            Err(anyhow!("worktree is not shared"))?
325        }
326    }
327
328    pub fn join_worktree(
329        &mut self,
330        connection_id: ConnectionId,
331        user_id: UserId,
332        worktree_id: u64,
333    ) -> tide::Result<JoinedWorktree> {
334        let connection = self
335            .connections
336            .get_mut(&connection_id)
337            .ok_or_else(|| anyhow!("no such connection"))?;
338        let worktree = self
339            .worktrees
340            .get_mut(&worktree_id)
341            .and_then(|worktree| {
342                if worktree.authorized_user_ids.contains(&user_id) {
343                    Some(worktree)
344                } else {
345                    None
346                }
347            })
348            .ok_or_else(|| anyhow!("no such worktree"))?;
349
350        let share = worktree.share_mut()?;
351        connection.worktrees.insert(worktree_id);
352
353        let mut replica_id = 1;
354        while share.active_replica_ids.contains(&replica_id) {
355            replica_id += 1;
356        }
357        share.active_replica_ids.insert(replica_id);
358        share.guests.insert(connection_id, (replica_id, user_id));
359
360        #[cfg(test)]
361        self.check_invariants();
362
363        Ok(JoinedWorktree {
364            replica_id,
365            worktree: &self.worktrees[&worktree_id],
366        })
367    }
368
369    pub fn leave_worktree(
370        &mut self,
371        connection_id: ConnectionId,
372        worktree_id: u64,
373    ) -> Option<LeftWorktree> {
374        let worktree = self.worktrees.get_mut(&worktree_id)?;
375        let share = worktree.share.as_mut()?;
376        let (replica_id, _) = share.guests.remove(&connection_id)?;
377        share.active_replica_ids.remove(&replica_id);
378
379        if let Some(connection) = self.connections.get_mut(&connection_id) {
380            connection.worktrees.remove(&worktree_id);
381        }
382
383        let connection_ids = worktree.connection_ids();
384        let authorized_user_ids = worktree.authorized_user_ids.clone();
385
386        #[cfg(test)]
387        self.check_invariants();
388
389        Some(LeftWorktree {
390            connection_ids,
391            authorized_user_ids,
392        })
393    }
394
395    pub fn update_worktree(
396        &mut self,
397        connection_id: ConnectionId,
398        worktree_id: u64,
399        removed_entries: &[u64],
400        updated_entries: &[proto::Entry],
401    ) -> tide::Result<Vec<ConnectionId>> {
402        let worktree = self.write_worktree(worktree_id, connection_id)?;
403        let share = worktree.share_mut()?;
404        for entry_id in removed_entries {
405            share.entries.remove(&entry_id);
406        }
407        for entry in updated_entries {
408            share.entries.insert(entry.id, entry.clone());
409        }
410        Ok(worktree.connection_ids())
411    }
412
413    pub fn worktree_host_connection_id(
414        &self,
415        connection_id: ConnectionId,
416        worktree_id: u64,
417    ) -> tide::Result<ConnectionId> {
418        Ok(self
419            .read_worktree(worktree_id, connection_id)?
420            .host_connection_id)
421    }
422
423    pub fn worktree_guest_connection_ids(
424        &self,
425        connection_id: ConnectionId,
426        worktree_id: u64,
427    ) -> tide::Result<Vec<ConnectionId>> {
428        Ok(self
429            .read_worktree(worktree_id, connection_id)?
430            .share()?
431            .guests
432            .keys()
433            .copied()
434            .collect())
435    }
436
437    pub fn worktree_connection_ids(
438        &self,
439        connection_id: ConnectionId,
440        worktree_id: u64,
441    ) -> tide::Result<Vec<ConnectionId>> {
442        Ok(self
443            .read_worktree(worktree_id, connection_id)?
444            .connection_ids())
445    }
446
447    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
448        Some(self.channels.get(&channel_id)?.connection_ids())
449    }
450
451    fn read_worktree(
452        &self,
453        worktree_id: u64,
454        connection_id: ConnectionId,
455    ) -> tide::Result<&Worktree> {
456        let worktree = self
457            .worktrees
458            .get(&worktree_id)
459            .ok_or_else(|| anyhow!("worktree not found"))?;
460
461        if worktree.host_connection_id == connection_id
462            || worktree.share()?.guests.contains_key(&connection_id)
463        {
464            Ok(worktree)
465        } else {
466            Err(anyhow!(
467                "{} is not a member of worktree {}",
468                connection_id,
469                worktree_id
470            ))?
471        }
472    }
473
474    fn write_worktree(
475        &mut self,
476        worktree_id: u64,
477        connection_id: ConnectionId,
478    ) -> tide::Result<&mut Worktree> {
479        let worktree = self
480            .worktrees
481            .get_mut(&worktree_id)
482            .ok_or_else(|| anyhow!("worktree not found"))?;
483
484        if worktree.host_connection_id == connection_id
485            || worktree
486                .share
487                .as_ref()
488                .map_or(false, |share| share.guests.contains_key(&connection_id))
489        {
490            Ok(worktree)
491        } else {
492            Err(anyhow!(
493                "{} is not a member of worktree {}",
494                connection_id,
495                worktree_id
496            ))?
497        }
498    }
499
500    #[cfg(test)]
501    fn check_invariants(&self) {
502        for (connection_id, connection) in &self.connections {
503            for worktree_id in &connection.worktrees {
504                let worktree = &self.worktrees.get(&worktree_id).unwrap();
505                if worktree.host_connection_id != *connection_id {
506                    assert!(worktree.share().unwrap().guests.contains_key(connection_id));
507                }
508            }
509            for channel_id in &connection.channels {
510                let channel = self.channels.get(channel_id).unwrap();
511                assert!(channel.connection_ids.contains(connection_id));
512            }
513            assert!(self
514                .connections_by_user_id
515                .get(&connection.user_id)
516                .unwrap()
517                .contains(connection_id));
518        }
519
520        for (user_id, connection_ids) in &self.connections_by_user_id {
521            for connection_id in connection_ids {
522                assert_eq!(
523                    self.connections.get(connection_id).unwrap().user_id,
524                    *user_id
525                );
526            }
527        }
528
529        for (worktree_id, worktree) in &self.worktrees {
530            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
531            assert!(host_connection.worktrees.contains(worktree_id));
532
533            for authorized_user_ids in &worktree.authorized_user_ids {
534                let visible_worktree_ids = self
535                    .visible_worktrees_by_user_id
536                    .get(authorized_user_ids)
537                    .unwrap();
538                assert!(visible_worktree_ids.contains(worktree_id));
539            }
540
541            if let Some(share) = &worktree.share {
542                for guest_connection_id in share.guests.keys() {
543                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
544                    assert!(guest_connection.worktrees.contains(worktree_id));
545                }
546                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
547                assert_eq!(
548                    share.active_replica_ids,
549                    share
550                        .guests
551                        .values()
552                        .map(|(replica_id, _)| *replica_id)
553                        .collect::<HashSet<_>>(),
554                );
555            }
556        }
557
558        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
559            for worktree_id in visible_worktree_ids {
560                let worktree = self.worktrees.get(worktree_id).unwrap();
561                assert!(worktree.authorized_user_ids.contains(user_id));
562            }
563        }
564
565        for (channel_id, channel) in &self.channels {
566            for connection_id in &channel.connection_ids {
567                let connection = self.connections.get(connection_id).unwrap();
568                assert!(connection.channels.contains(channel_id));
569            }
570        }
571    }
572}
573
574impl Worktree {
575    pub fn connection_ids(&self) -> Vec<ConnectionId> {
576        if let Some(share) = &self.share {
577            share
578                .guests
579                .keys()
580                .copied()
581                .chain(Some(self.host_connection_id))
582                .collect()
583        } else {
584            vec![self.host_connection_id]
585        }
586    }
587
588    pub fn share(&self) -> tide::Result<&WorktreeShare> {
589        Ok(self
590            .share
591            .as_ref()
592            .ok_or_else(|| anyhow!("worktree is not shared"))?)
593    }
594
595    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
596        Ok(self
597            .share
598            .as_mut()
599            .ok_or_else(|| anyhow!("worktree is not shared"))?)
600    }
601}
602
603impl Channel {
604    fn connection_ids(&self) -> Vec<ConnectionId> {
605        self.connection_ids.iter().copied().collect()
606    }
607}