store.rs

  1use crate::db::{ChannelId, UserId};
  2use crate::errors::TideResultExt;
  3use anyhow::anyhow;
  4use std::collections::{hash_map, HashMap, HashSet};
  5use zrpc::{proto, ConnectionId};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    worktrees: HashMap<u64, Worktree>,
 12    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_worktree_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    worktrees: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Worktree {
 24    pub host_connection_id: ConnectionId,
 25    pub collaborator_user_ids: Vec<UserId>,
 26    pub root_name: String,
 27    pub share: Option<WorktreeShare>,
 28}
 29
 30pub struct WorktreeShare {
 31    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 32    pub active_replica_ids: HashSet<ReplicaId>,
 33    pub entries: HashMap<u64, proto::Entry>,
 34}
 35
 36#[derive(Default)]
 37pub struct Channel {
 38    pub connection_ids: HashSet<ConnectionId>,
 39}
 40
 41pub type ReplicaId = u16;
 42
 43#[derive(Default)]
 44pub struct RemovedConnectionState {
 45    pub hosted_worktrees: HashMap<u64, Worktree>,
 46    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 47    pub collaborator_ids: HashSet<UserId>,
 48}
 49
 50pub struct JoinedWorktree<'a> {
 51    pub replica_id: ReplicaId,
 52    pub worktree: &'a Worktree,
 53}
 54
 55pub struct UnsharedWorktree {
 56    pub connection_ids: Vec<ConnectionId>,
 57    pub collaborator_ids: Vec<UserId>,
 58}
 59
 60pub struct LeftWorktree {
 61    pub connection_ids: Vec<ConnectionId>,
 62    pub collaborator_ids: Vec<UserId>,
 63}
 64
 65impl Store {
 66    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 67        self.connections.insert(
 68            connection_id,
 69            ConnectionState {
 70                user_id,
 71                worktrees: Default::default(),
 72                channels: Default::default(),
 73            },
 74        );
 75        self.connections_by_user_id
 76            .entry(user_id)
 77            .or_default()
 78            .insert(connection_id);
 79    }
 80
 81    pub fn remove_connection(
 82        &mut self,
 83        connection_id: ConnectionId,
 84    ) -> tide::Result<RemovedConnectionState> {
 85        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 86            connection
 87        } else {
 88            return Err(anyhow!("no such connection"))?;
 89        };
 90
 91        for channel_id in &connection.channels {
 92            if let Some(channel) = self.channels.get_mut(&channel_id) {
 93                channel.connection_ids.remove(&connection_id);
 94            }
 95        }
 96
 97        let user_connections = self
 98            .connections_by_user_id
 99            .get_mut(&connection.user_id)
100            .unwrap();
101        user_connections.remove(&connection_id);
102        if user_connections.is_empty() {
103            self.connections_by_user_id.remove(&connection.user_id);
104        }
105
106        let mut result = RemovedConnectionState::default();
107        for worktree_id in connection.worktrees.clone() {
108            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109                result
110                    .collaborator_ids
111                    .extend(worktree.collaborator_user_ids.iter().copied());
112                result.hosted_worktrees.insert(worktree_id, worktree);
113            } else {
114                if let Some(worktree) = self.worktrees.get(&worktree_id) {
115                    result
116                        .guest_worktree_ids
117                        .insert(worktree_id, worktree.connection_ids());
118                    result
119                        .collaborator_ids
120                        .extend(worktree.collaborator_user_ids.iter().copied());
121                }
122            }
123        }
124
125        Ok(result)
126    }
127
128    #[cfg(test)]
129    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130        self.channels.get(&id)
131    }
132
133    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134        if let Some(connection) = self.connections.get_mut(&connection_id) {
135            connection.channels.insert(channel_id);
136            self.channels
137                .entry(channel_id)
138                .or_default()
139                .connection_ids
140                .insert(connection_id);
141        }
142    }
143
144    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145        if let Some(connection) = self.connections.get_mut(&connection_id) {
146            connection.channels.remove(&channel_id);
147            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148                entry.get_mut().connection_ids.remove(&connection_id);
149                if entry.get_mut().connection_ids.is_empty() {
150                    entry.remove();
151                }
152            }
153        }
154    }
155
156    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157        Ok(self
158            .connections
159            .get(&connection_id)
160            .ok_or_else(|| anyhow!("unknown connection"))?
161            .user_id)
162    }
163
164    pub fn connection_ids_for_user<'a>(
165        &'a self,
166        user_id: UserId,
167    ) -> impl 'a + Iterator<Item = ConnectionId> {
168        self.connections_by_user_id
169            .get(&user_id)
170            .into_iter()
171            .flatten()
172            .copied()
173    }
174
175    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
176        let mut collaborators = HashMap::new();
177        for worktree_id in self
178            .visible_worktrees_by_user_id
179            .get(&user_id)
180            .unwrap_or(&HashSet::new())
181        {
182            let worktree = &self.worktrees[worktree_id];
183
184            let mut guests = HashSet::new();
185            if let Ok(share) = worktree.share() {
186                for guest_connection_id in share.guest_connection_ids.keys() {
187                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188                        guests.insert(user_id.to_proto());
189                    }
190                }
191            }
192
193            if let Ok(host_user_id) = self
194                .user_id_for_connection(worktree.host_connection_id)
195                .context("stale worktree host connection")
196            {
197                let host =
198                    collaborators
199                        .entry(host_user_id)
200                        .or_insert_with(|| proto::Collaborator {
201                            user_id: host_user_id.to_proto(),
202                            worktrees: Vec::new(),
203                        });
204                host.worktrees.push(proto::WorktreeMetadata {
205                    id: *worktree_id,
206                    root_name: worktree.root_name.clone(),
207                    is_shared: worktree.share().is_ok(),
208                    guests: guests.into_iter().collect(),
209                });
210            }
211        }
212
213        collaborators.into_values().collect()
214    }
215
216    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
217        let worktree_id = self.next_worktree_id;
218        for collaborator_user_id in &worktree.collaborator_user_ids {
219            self.visible_worktrees_by_user_id
220                .entry(*collaborator_user_id)
221                .or_default()
222                .insert(worktree_id);
223        }
224        self.next_worktree_id += 1;
225        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
226            connection.worktrees.insert(worktree_id);
227        }
228        self.worktrees.insert(worktree_id, worktree);
229
230        #[cfg(test)]
231        self.check_invariants();
232
233        worktree_id
234    }
235
236    pub fn remove_worktree(
237        &mut self,
238        worktree_id: u64,
239        acting_connection_id: ConnectionId,
240    ) -> tide::Result<Worktree> {
241        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
242            if e.get().host_connection_id != acting_connection_id {
243                Err(anyhow!("not your worktree"))?;
244            }
245            e.remove()
246        } else {
247            return Err(anyhow!("no such worktree"))?;
248        };
249
250        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
251            connection.worktrees.remove(&worktree_id);
252        }
253
254        if let Some(share) = &worktree.share {
255            for connection_id in share.guest_connection_ids.keys() {
256                if let Some(connection) = self.connections.get_mut(connection_id) {
257                    connection.worktrees.remove(&worktree_id);
258                }
259            }
260        }
261
262        for collaborator_user_id in &worktree.collaborator_user_ids {
263            if let Some(visible_worktrees) = self
264                .visible_worktrees_by_user_id
265                .get_mut(&collaborator_user_id)
266            {
267                visible_worktrees.remove(&worktree_id);
268            }
269        }
270
271        #[cfg(test)]
272        self.check_invariants();
273
274        Ok(worktree)
275    }
276
277    pub fn share_worktree(
278        &mut self,
279        worktree_id: u64,
280        connection_id: ConnectionId,
281        entries: HashMap<u64, proto::Entry>,
282    ) -> Option<Vec<UserId>> {
283        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
284            if worktree.host_connection_id == connection_id {
285                worktree.share = Some(WorktreeShare {
286                    guest_connection_ids: Default::default(),
287                    active_replica_ids: Default::default(),
288                    entries,
289                });
290                return Some(worktree.collaborator_user_ids.clone());
291            }
292        }
293        None
294    }
295
296    pub fn unshare_worktree(
297        &mut self,
298        worktree_id: u64,
299        acting_connection_id: ConnectionId,
300    ) -> tide::Result<UnsharedWorktree> {
301        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
302            worktree
303        } else {
304            return Err(anyhow!("no such worktree"))?;
305        };
306
307        if worktree.host_connection_id != acting_connection_id {
308            return Err(anyhow!("not your worktree"))?;
309        }
310
311        let connection_ids = worktree.connection_ids();
312
313        if let Some(_) = worktree.share.take() {
314            for connection_id in &connection_ids {
315                if let Some(connection) = self.connections.get_mut(connection_id) {
316                    connection.worktrees.remove(&worktree_id);
317                }
318            }
319            Ok(UnsharedWorktree {
320                connection_ids,
321                collaborator_ids: worktree.collaborator_user_ids.clone(),
322            })
323        } else {
324            Err(anyhow!("worktree is not shared"))?
325        }
326    }
327
328    pub fn join_worktree(
329        &mut self,
330        connection_id: ConnectionId,
331        user_id: UserId,
332        worktree_id: u64,
333    ) -> tide::Result<JoinedWorktree> {
334        let connection = self
335            .connections
336            .get_mut(&connection_id)
337            .ok_or_else(|| anyhow!("no such connection"))?;
338        let worktree = self
339            .worktrees
340            .get_mut(&worktree_id)
341            .and_then(|worktree| {
342                if worktree.collaborator_user_ids.contains(&user_id) {
343                    Some(worktree)
344                } else {
345                    None
346                }
347            })
348            .ok_or_else(|| anyhow!("no such worktree"))?;
349
350        let share = worktree.share_mut()?;
351        connection.worktrees.insert(worktree_id);
352
353        let mut replica_id = 1;
354        while share.active_replica_ids.contains(&replica_id) {
355            replica_id += 1;
356        }
357        share.active_replica_ids.insert(replica_id);
358        share.guest_connection_ids.insert(connection_id, replica_id);
359        Ok(JoinedWorktree {
360            replica_id,
361            worktree,
362        })
363    }
364
365    pub fn leave_worktree(
366        &mut self,
367        connection_id: ConnectionId,
368        worktree_id: u64,
369    ) -> Option<LeftWorktree> {
370        let worktree = self.worktrees.get_mut(&worktree_id)?;
371        let share = worktree.share.as_mut()?;
372        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
373        share.active_replica_ids.remove(&replica_id);
374        Some(LeftWorktree {
375            connection_ids: worktree.connection_ids(),
376            collaborator_ids: worktree.collaborator_user_ids.clone(),
377        })
378    }
379
380    pub fn update_worktree(
381        &mut self,
382        connection_id: ConnectionId,
383        worktree_id: u64,
384        removed_entries: &[u64],
385        updated_entries: &[proto::Entry],
386    ) -> tide::Result<Vec<ConnectionId>> {
387        let worktree = self.write_worktree(worktree_id, connection_id)?;
388        let share = worktree.share_mut()?;
389        for entry_id in removed_entries {
390            share.entries.remove(&entry_id);
391        }
392        for entry in updated_entries {
393            share.entries.insert(entry.id, entry.clone());
394        }
395        Ok(worktree.connection_ids())
396    }
397
398    pub fn worktree_host_connection_id(
399        &self,
400        connection_id: ConnectionId,
401        worktree_id: u64,
402    ) -> tide::Result<ConnectionId> {
403        Ok(self
404            .read_worktree(worktree_id, connection_id)?
405            .host_connection_id)
406    }
407
408    pub fn worktree_guest_connection_ids(
409        &self,
410        connection_id: ConnectionId,
411        worktree_id: u64,
412    ) -> tide::Result<Vec<ConnectionId>> {
413        Ok(self
414            .read_worktree(worktree_id, connection_id)?
415            .share()?
416            .guest_connection_ids
417            .keys()
418            .copied()
419            .collect())
420    }
421
422    pub fn worktree_connection_ids(
423        &self,
424        connection_id: ConnectionId,
425        worktree_id: u64,
426    ) -> tide::Result<Vec<ConnectionId>> {
427        Ok(self
428            .read_worktree(worktree_id, connection_id)?
429            .connection_ids())
430    }
431
432    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
433        Some(self.channels.get(&channel_id)?.connection_ids())
434    }
435
436    fn read_worktree(
437        &self,
438        worktree_id: u64,
439        connection_id: ConnectionId,
440    ) -> tide::Result<&Worktree> {
441        let worktree = self
442            .worktrees
443            .get(&worktree_id)
444            .ok_or_else(|| anyhow!("worktree not found"))?;
445
446        if worktree.host_connection_id == connection_id
447            || worktree
448                .share()?
449                .guest_connection_ids
450                .contains_key(&connection_id)
451        {
452            Ok(worktree)
453        } else {
454            Err(anyhow!(
455                "{} is not a member of worktree {}",
456                connection_id,
457                worktree_id
458            ))?
459        }
460    }
461
462    fn write_worktree(
463        &mut self,
464        worktree_id: u64,
465        connection_id: ConnectionId,
466    ) -> tide::Result<&mut Worktree> {
467        let worktree = self
468            .worktrees
469            .get_mut(&worktree_id)
470            .ok_or_else(|| anyhow!("worktree not found"))?;
471
472        if worktree.host_connection_id == connection_id
473            || worktree.share.as_ref().map_or(false, |share| {
474                share.guest_connection_ids.contains_key(&connection_id)
475            })
476        {
477            Ok(worktree)
478        } else {
479            Err(anyhow!(
480                "{} is not a member of worktree {}",
481                connection_id,
482                worktree_id
483            ))?
484        }
485    }
486
487    #[cfg(test)]
488    fn check_invariants(&self) {
489        for (connection_id, connection) in &self.connections {
490            for worktree_id in &connection.worktrees {
491                let worktree = &self.worktrees.get(&worktree_id).unwrap();
492                if worktree.host_connection_id != *connection_id {
493                    assert!(worktree
494                        .share()
495                        .unwrap()
496                        .guest_connection_ids
497                        .contains_key(connection_id));
498                }
499            }
500            for channel_id in &connection.channels {
501                let channel = self.channels.get(channel_id).unwrap();
502                assert!(channel.connection_ids.contains(connection_id));
503            }
504            assert!(self
505                .connections_by_user_id
506                .get(&connection.user_id)
507                .unwrap()
508                .contains(connection_id));
509        }
510
511        for (user_id, connection_ids) in &self.connections_by_user_id {
512            for connection_id in connection_ids {
513                assert_eq!(
514                    self.connections.get(connection_id).unwrap().user_id,
515                    *user_id
516                );
517            }
518        }
519
520        for (worktree_id, worktree) in &self.worktrees {
521            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
522            assert!(host_connection.worktrees.contains(worktree_id));
523
524            for collaborator_id in &worktree.collaborator_user_ids {
525                let visible_worktree_ids = self
526                    .visible_worktrees_by_user_id
527                    .get(collaborator_id)
528                    .unwrap();
529                assert!(visible_worktree_ids.contains(worktree_id));
530            }
531
532            if let Some(share) = &worktree.share {
533                for guest_connection_id in share.guest_connection_ids.keys() {
534                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
535                    assert!(guest_connection.worktrees.contains(worktree_id));
536                }
537                assert_eq!(
538                    share.active_replica_ids.len(),
539                    share.guest_connection_ids.len(),
540                );
541                assert_eq!(
542                    share.active_replica_ids,
543                    share
544                        .guest_connection_ids
545                        .values()
546                        .copied()
547                        .collect::<HashSet<_>>(),
548                );
549            }
550        }
551
552        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
553            for worktree_id in visible_worktree_ids {
554                let worktree = self.worktrees.get(worktree_id).unwrap();
555                assert!(worktree.collaborator_user_ids.contains(user_id));
556            }
557        }
558
559        for (channel_id, channel) in &self.channels {
560            for connection_id in &channel.connection_ids {
561                let connection = self.connections.get(connection_id).unwrap();
562                assert!(connection.channels.contains(channel_id));
563            }
564        }
565    }
566}
567
568impl Worktree {
569    pub fn connection_ids(&self) -> Vec<ConnectionId> {
570        if let Some(share) = &self.share {
571            share
572                .guest_connection_ids
573                .keys()
574                .copied()
575                .chain(Some(self.host_connection_id))
576                .collect()
577        } else {
578            vec![self.host_connection_id]
579        }
580    }
581
582    pub fn share(&self) -> tide::Result<&WorktreeShare> {
583        Ok(self
584            .share
585            .as_ref()
586            .ok_or_else(|| anyhow!("worktree is not shared"))?)
587    }
588
589    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
590        Ok(self
591            .share
592            .as_mut()
593            .ok_or_else(|| anyhow!("worktree is not shared"))?)
594    }
595}
596
597impl Channel {
598    fn connection_ids(&self) -> Vec<ConnectionId> {
599        self.connection_ids.iter().copied().collect()
600    }
601}