store.rs

  1use crate::db::{ChannelId, UserId};
  2use anyhow::anyhow;
  3use collections::{HashMap, HashSet};
  4use rpc::{proto, ConnectionId};
  5use std::collections::hash_map;
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    worktrees: HashMap<u64, Worktree>,
 12    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_worktree_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    worktrees: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Worktree {
 24    pub host_connection_id: ConnectionId,
 25    pub host_user_id: UserId,
 26    pub authorized_user_ids: Vec<UserId>,
 27    pub root_name: String,
 28    pub share: Option<WorktreeShare>,
 29}
 30
 31pub struct WorktreeShare {
 32    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 33    pub active_replica_ids: HashSet<ReplicaId>,
 34    pub entries: HashMap<u64, proto::Entry>,
 35}
 36
 37#[derive(Default)]
 38pub struct Channel {
 39    pub connection_ids: HashSet<ConnectionId>,
 40}
 41
 42pub type ReplicaId = u16;
 43
 44#[derive(Default)]
 45pub struct RemovedConnectionState {
 46    pub hosted_worktrees: HashMap<u64, Worktree>,
 47    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 48    pub contact_ids: HashSet<UserId>,
 49}
 50
 51pub struct JoinedWorktree<'a> {
 52    pub replica_id: ReplicaId,
 53    pub worktree: &'a Worktree,
 54}
 55
 56pub struct UnsharedWorktree {
 57    pub connection_ids: Vec<ConnectionId>,
 58    pub authorized_user_ids: Vec<UserId>,
 59}
 60
 61pub struct LeftWorktree {
 62    pub connection_ids: Vec<ConnectionId>,
 63    pub authorized_user_ids: Vec<UserId>,
 64}
 65
 66impl Store {
 67    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 68        self.connections.insert(
 69            connection_id,
 70            ConnectionState {
 71                user_id,
 72                worktrees: Default::default(),
 73                channels: Default::default(),
 74            },
 75        );
 76        self.connections_by_user_id
 77            .entry(user_id)
 78            .or_default()
 79            .insert(connection_id);
 80    }
 81
 82    pub fn remove_connection(
 83        &mut self,
 84        connection_id: ConnectionId,
 85    ) -> tide::Result<RemovedConnectionState> {
 86        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
 87            connection
 88        } else {
 89            return Err(anyhow!("no such connection"))?;
 90        };
 91
 92        for channel_id in &connection.channels {
 93            if let Some(channel) = self.channels.get_mut(&channel_id) {
 94                channel.connection_ids.remove(&connection_id);
 95            }
 96        }
 97
 98        let user_connections = self
 99            .connections_by_user_id
100            .get_mut(&connection.user_id)
101            .unwrap();
102        user_connections.remove(&connection_id);
103        if user_connections.is_empty() {
104            self.connections_by_user_id.remove(&connection.user_id);
105        }
106
107        let mut result = RemovedConnectionState::default();
108        for worktree_id in connection.worktrees.clone() {
109            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
110                result
111                    .contact_ids
112                    .extend(worktree.authorized_user_ids.iter().copied());
113                result.hosted_worktrees.insert(worktree_id, worktree);
114            } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
115                result
116                    .guest_worktree_ids
117                    .insert(worktree_id, worktree.connection_ids);
118                result.contact_ids.extend(worktree.authorized_user_ids);
119            }
120        }
121
122        #[cfg(test)]
123        self.check_invariants();
124
125        Ok(result)
126    }
127
128    #[cfg(test)]
129    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130        self.channels.get(&id)
131    }
132
133    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134        if let Some(connection) = self.connections.get_mut(&connection_id) {
135            connection.channels.insert(channel_id);
136            self.channels
137                .entry(channel_id)
138                .or_default()
139                .connection_ids
140                .insert(connection_id);
141        }
142    }
143
144    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145        if let Some(connection) = self.connections.get_mut(&connection_id) {
146            connection.channels.remove(&channel_id);
147            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148                entry.get_mut().connection_ids.remove(&connection_id);
149                if entry.get_mut().connection_ids.is_empty() {
150                    entry.remove();
151                }
152            }
153        }
154    }
155
156    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157        Ok(self
158            .connections
159            .get(&connection_id)
160            .ok_or_else(|| anyhow!("unknown connection"))?
161            .user_id)
162    }
163
164    pub fn connection_ids_for_user<'a>(
165        &'a self,
166        user_id: UserId,
167    ) -> impl 'a + Iterator<Item = ConnectionId> {
168        self.connections_by_user_id
169            .get(&user_id)
170            .into_iter()
171            .flatten()
172            .copied()
173    }
174
175    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
176        let mut contacts = HashMap::default();
177        for worktree_id in self
178            .visible_worktrees_by_user_id
179            .get(&user_id)
180            .unwrap_or(&HashSet::default())
181        {
182            let worktree = &self.worktrees[worktree_id];
183
184            let mut guests = HashSet::default();
185            if let Ok(share) = worktree.share() {
186                for guest_connection_id in share.guests.keys() {
187                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188                        guests.insert(user_id.to_proto());
189                    }
190                }
191            }
192
193            if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
194                contacts
195                    .entry(host_user_id)
196                    .or_insert_with(|| proto::Contact {
197                        user_id: host_user_id.to_proto(),
198                        worktrees: Vec::new(),
199                    })
200                    .worktrees
201                    .push(proto::WorktreeMetadata {
202                        id: *worktree_id,
203                        root_name: worktree.root_name.clone(),
204                        is_shared: worktree.share.is_some(),
205                        guests: guests.into_iter().collect(),
206                    });
207            }
208        }
209
210        contacts.into_values().collect()
211    }
212
213    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
214        let worktree_id = self.next_worktree_id;
215        for authorized_user_id in &worktree.authorized_user_ids {
216            self.visible_worktrees_by_user_id
217                .entry(*authorized_user_id)
218                .or_default()
219                .insert(worktree_id);
220        }
221        self.next_worktree_id += 1;
222        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
223            connection.worktrees.insert(worktree_id);
224        }
225        self.worktrees.insert(worktree_id, worktree);
226
227        #[cfg(test)]
228        self.check_invariants();
229
230        worktree_id
231    }
232
233    pub fn remove_worktree(
234        &mut self,
235        worktree_id: u64,
236        acting_connection_id: ConnectionId,
237    ) -> tide::Result<Worktree> {
238        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
239            if e.get().host_connection_id != acting_connection_id {
240                Err(anyhow!("not your worktree"))?;
241            }
242            e.remove()
243        } else {
244            return Err(anyhow!("no such worktree"))?;
245        };
246
247        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
248            connection.worktrees.remove(&worktree_id);
249        }
250
251        if let Some(share) = &worktree.share {
252            for connection_id in share.guests.keys() {
253                if let Some(connection) = self.connections.get_mut(connection_id) {
254                    connection.worktrees.remove(&worktree_id);
255                }
256            }
257        }
258
259        for authorized_user_id in &worktree.authorized_user_ids {
260            if let Some(visible_worktrees) = self
261                .visible_worktrees_by_user_id
262                .get_mut(&authorized_user_id)
263            {
264                visible_worktrees.remove(&worktree_id);
265            }
266        }
267
268        #[cfg(test)]
269        self.check_invariants();
270
271        Ok(worktree)
272    }
273
274    pub fn share_worktree(
275        &mut self,
276        worktree_id: u64,
277        connection_id: ConnectionId,
278        entries: HashMap<u64, proto::Entry>,
279    ) -> Option<Vec<UserId>> {
280        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
281            if worktree.host_connection_id == connection_id {
282                worktree.share = Some(WorktreeShare {
283                    guests: Default::default(),
284                    active_replica_ids: Default::default(),
285                    entries,
286                });
287                return Some(worktree.authorized_user_ids.clone());
288            }
289        }
290        None
291    }
292
293    pub fn unshare_worktree(
294        &mut self,
295        worktree_id: u64,
296        acting_connection_id: ConnectionId,
297    ) -> tide::Result<UnsharedWorktree> {
298        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
299            worktree
300        } else {
301            return Err(anyhow!("no such worktree"))?;
302        };
303
304        if worktree.host_connection_id != acting_connection_id {
305            return Err(anyhow!("not your worktree"))?;
306        }
307
308        let connection_ids = worktree.connection_ids();
309        let authorized_user_ids = worktree.authorized_user_ids.clone();
310        if let Some(share) = worktree.share.take() {
311            for connection_id in share.guests.into_keys() {
312                if let Some(connection) = self.connections.get_mut(&connection_id) {
313                    connection.worktrees.remove(&worktree_id);
314                }
315            }
316
317            #[cfg(test)]
318            self.check_invariants();
319
320            Ok(UnsharedWorktree {
321                connection_ids,
322                authorized_user_ids,
323            })
324        } else {
325            Err(anyhow!("worktree is not shared"))?
326        }
327    }
328
329    pub fn join_worktree(
330        &mut self,
331        connection_id: ConnectionId,
332        user_id: UserId,
333        worktree_id: u64,
334    ) -> tide::Result<JoinedWorktree> {
335        let connection = self
336            .connections
337            .get_mut(&connection_id)
338            .ok_or_else(|| anyhow!("no such connection"))?;
339        let worktree = self
340            .worktrees
341            .get_mut(&worktree_id)
342            .and_then(|worktree| {
343                if worktree.authorized_user_ids.contains(&user_id) {
344                    Some(worktree)
345                } else {
346                    None
347                }
348            })
349            .ok_or_else(|| anyhow!("no such worktree"))?;
350
351        let share = worktree.share_mut()?;
352        connection.worktrees.insert(worktree_id);
353
354        let mut replica_id = 1;
355        while share.active_replica_ids.contains(&replica_id) {
356            replica_id += 1;
357        }
358        share.active_replica_ids.insert(replica_id);
359        share.guests.insert(connection_id, (replica_id, user_id));
360
361        #[cfg(test)]
362        self.check_invariants();
363
364        Ok(JoinedWorktree {
365            replica_id,
366            worktree: &self.worktrees[&worktree_id],
367        })
368    }
369
370    pub fn leave_worktree(
371        &mut self,
372        connection_id: ConnectionId,
373        worktree_id: u64,
374    ) -> Option<LeftWorktree> {
375        let worktree = self.worktrees.get_mut(&worktree_id)?;
376        let share = worktree.share.as_mut()?;
377        let (replica_id, _) = share.guests.remove(&connection_id)?;
378        share.active_replica_ids.remove(&replica_id);
379
380        if let Some(connection) = self.connections.get_mut(&connection_id) {
381            connection.worktrees.remove(&worktree_id);
382        }
383
384        let connection_ids = worktree.connection_ids();
385        let authorized_user_ids = worktree.authorized_user_ids.clone();
386
387        #[cfg(test)]
388        self.check_invariants();
389
390        Some(LeftWorktree {
391            connection_ids,
392            authorized_user_ids,
393        })
394    }
395
396    pub fn update_worktree(
397        &mut self,
398        connection_id: ConnectionId,
399        worktree_id: u64,
400        removed_entries: &[u64],
401        updated_entries: &[proto::Entry],
402    ) -> tide::Result<Vec<ConnectionId>> {
403        let worktree = self.write_worktree(worktree_id, connection_id)?;
404        let share = worktree.share_mut()?;
405        for entry_id in removed_entries {
406            share.entries.remove(&entry_id);
407        }
408        for entry in updated_entries {
409            share.entries.insert(entry.id, entry.clone());
410        }
411        Ok(worktree.connection_ids())
412    }
413
414    pub fn worktree_host_connection_id(
415        &self,
416        connection_id: ConnectionId,
417        worktree_id: u64,
418    ) -> tide::Result<ConnectionId> {
419        Ok(self
420            .read_worktree(worktree_id, connection_id)?
421            .host_connection_id)
422    }
423
424    pub fn worktree_guest_connection_ids(
425        &self,
426        connection_id: ConnectionId,
427        worktree_id: u64,
428    ) -> tide::Result<Vec<ConnectionId>> {
429        Ok(self
430            .read_worktree(worktree_id, connection_id)?
431            .share()?
432            .guests
433            .keys()
434            .copied()
435            .collect())
436    }
437
438    pub fn worktree_connection_ids(
439        &self,
440        connection_id: ConnectionId,
441        worktree_id: u64,
442    ) -> tide::Result<Vec<ConnectionId>> {
443        Ok(self
444            .read_worktree(worktree_id, connection_id)?
445            .connection_ids())
446    }
447
448    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
449        Some(self.channels.get(&channel_id)?.connection_ids())
450    }
451
452    fn read_worktree(
453        &self,
454        worktree_id: u64,
455        connection_id: ConnectionId,
456    ) -> tide::Result<&Worktree> {
457        let worktree = self
458            .worktrees
459            .get(&worktree_id)
460            .ok_or_else(|| anyhow!("worktree not found"))?;
461
462        if worktree.host_connection_id == connection_id
463            || worktree.share()?.guests.contains_key(&connection_id)
464        {
465            Ok(worktree)
466        } else {
467            Err(anyhow!(
468                "{} is not a member of worktree {}",
469                connection_id,
470                worktree_id
471            ))?
472        }
473    }
474
475    fn write_worktree(
476        &mut self,
477        worktree_id: u64,
478        connection_id: ConnectionId,
479    ) -> tide::Result<&mut Worktree> {
480        let worktree = self
481            .worktrees
482            .get_mut(&worktree_id)
483            .ok_or_else(|| anyhow!("worktree not found"))?;
484
485        if worktree.host_connection_id == connection_id
486            || worktree
487                .share
488                .as_ref()
489                .map_or(false, |share| share.guests.contains_key(&connection_id))
490        {
491            Ok(worktree)
492        } else {
493            Err(anyhow!(
494                "{} is not a member of worktree {}",
495                connection_id,
496                worktree_id
497            ))?
498        }
499    }
500
501    #[cfg(test)]
502    fn check_invariants(&self) {
503        for (connection_id, connection) in &self.connections {
504            for worktree_id in &connection.worktrees {
505                let worktree = &self.worktrees.get(&worktree_id).unwrap();
506                if worktree.host_connection_id != *connection_id {
507                    assert!(worktree.share().unwrap().guests.contains_key(connection_id));
508                }
509            }
510            for channel_id in &connection.channels {
511                let channel = self.channels.get(channel_id).unwrap();
512                assert!(channel.connection_ids.contains(connection_id));
513            }
514            assert!(self
515                .connections_by_user_id
516                .get(&connection.user_id)
517                .unwrap()
518                .contains(connection_id));
519        }
520
521        for (user_id, connection_ids) in &self.connections_by_user_id {
522            for connection_id in connection_ids {
523                assert_eq!(
524                    self.connections.get(connection_id).unwrap().user_id,
525                    *user_id
526                );
527            }
528        }
529
530        for (worktree_id, worktree) in &self.worktrees {
531            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
532            assert!(host_connection.worktrees.contains(worktree_id));
533
534            for authorized_user_ids in &worktree.authorized_user_ids {
535                let visible_worktree_ids = self
536                    .visible_worktrees_by_user_id
537                    .get(authorized_user_ids)
538                    .unwrap();
539                assert!(visible_worktree_ids.contains(worktree_id));
540            }
541
542            if let Some(share) = &worktree.share {
543                for guest_connection_id in share.guests.keys() {
544                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
545                    assert!(guest_connection.worktrees.contains(worktree_id));
546                }
547                assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
548                assert_eq!(
549                    share.active_replica_ids,
550                    share
551                        .guests
552                        .values()
553                        .map(|(replica_id, _)| *replica_id)
554                        .collect::<HashSet<_>>(),
555                );
556            }
557        }
558
559        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
560            for worktree_id in visible_worktree_ids {
561                let worktree = self.worktrees.get(worktree_id).unwrap();
562                assert!(worktree.authorized_user_ids.contains(user_id));
563            }
564        }
565
566        for (channel_id, channel) in &self.channels {
567            for connection_id in &channel.connection_ids {
568                let connection = self.connections.get(connection_id).unwrap();
569                assert!(connection.channels.contains(channel_id));
570            }
571        }
572    }
573}
574
575impl Worktree {
576    pub fn connection_ids(&self) -> Vec<ConnectionId> {
577        if let Some(share) = &self.share {
578            share
579                .guests
580                .keys()
581                .copied()
582                .chain(Some(self.host_connection_id))
583                .collect()
584        } else {
585            vec![self.host_connection_id]
586        }
587    }
588
589    pub fn share(&self) -> tide::Result<&WorktreeShare> {
590        Ok(self
591            .share
592            .as_ref()
593            .ok_or_else(|| anyhow!("worktree is not shared"))?)
594    }
595
596    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
597        Ok(self
598            .share
599            .as_mut()
600            .ok_or_else(|| anyhow!("worktree is not shared"))?)
601    }
602}
603
604impl Channel {
605    fn connection_ids(&self) -> Vec<ConnectionId> {
606        self.connection_ids.iter().copied().collect()
607    }
608}