// store.rs

  1use crate::db::{ChannelId, UserId};
  2use crate::errors::TideResultExt;
  3use anyhow::anyhow;
  4use std::collections::{hash_map, HashMap, HashSet};
  5use zrpc::{proto, ConnectionId};
  6
/// In-memory bookkeeping for connections, worktrees, and channels, together
/// with the reverse indices that are kept in sync with them.
#[derive(Default)]
pub struct Store {
    /// Every registered connection, keyed by its id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index: the connections currently registered for each user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Registered worktrees, keyed by the id allocated in `add_worktree`.
    worktrees: HashMap<u64, Worktree>,
    /// For each user, the ids of worktrees that list that user in
    /// `collaborator_user_ids`.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channel membership, keyed by channel id.
    channels: HashMap<ChannelId, Channel>,
    /// Id handed to the next worktree added; starts at 0 via `Default` and
    /// only ever increments.
    next_worktree_id: u64,
}
 16
/// Per-connection state tracked by the store.
struct ConnectionState {
    /// The user that opened this connection.
    user_id: UserId,
    /// Ids of worktrees this connection hosts or has joined as a guest.
    worktrees: HashSet<u64>,
    /// Ids of channels this connection has joined.
    channels: HashSet<ChannelId>,
}
 22
/// A worktree registered by a host connection.
pub struct Worktree {
    /// The connection that registered (hosts) this worktree.
    pub host_connection_id: ConnectionId,
    /// Users who can see this worktree and who are allowed to join it.
    pub collaborator_user_ids: Vec<UserId>,
    /// Root name reported in collaborator metadata.
    pub root_name: String,
    /// `Some` while the worktree is shared; `None` otherwise.
    pub share: Option<WorktreeShare>,
}
 29
/// State that exists only while a worktree is shared.
pub struct WorktreeShare {
    /// Guest connections mapped to the replica id each was assigned on join.
    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently in use; expected to equal the set of values in
    /// `guest_connection_ids`.
    pub active_replica_ids: HashSet<ReplicaId>,
    /// The shared entries, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
}
 35
/// Membership of a single channel.
#[derive(Default)]
pub struct Channel {
    /// Connections currently in the channel.
    pub connection_ids: HashSet<ConnectionId>,
}
 40
/// Replica id assigned to a guest connection of a shared worktree.
pub type ReplicaId = u16;
 42
/// Summary of what was torn down when a connection was removed.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Worktrees that were removed because the connection hosted them.
    pub hosted_worktrees: HashMap<u64, Worktree>,
    /// Worktrees the connection had joined as a guest, mapped to that
    /// worktree's connection ids as computed during removal.
    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Collaborators of every affected worktree.
    pub collaborator_ids: HashSet<UserId>,
}
 49
/// Result of a connection joining a shared worktree.
pub struct JoinedWorktree<'a> {
    /// Replica id assigned to the joining connection.
    pub replica_id: ReplicaId,
    /// The worktree that was joined.
    pub worktree: &'a Worktree,
}
 54
/// Result of a host unsharing a worktree.
pub struct UnsharedWorktree {
    /// Connections attached to the worktree (host and guests) at the moment
    /// it was unshared.
    pub connection_ids: Vec<ConnectionId>,
    /// Collaborators of the worktree.
    pub collaborator_ids: Vec<UserId>,
}
 59
/// Result of a guest leaving a shared worktree.
pub struct LeftWorktree {
    /// The worktree's connection ids after the guest was removed.
    pub connection_ids: Vec<ConnectionId>,
    /// Collaborators of the worktree.
    pub collaborator_ids: Vec<UserId>,
}
 64
 65impl Store {
 66    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 67        self.connections.insert(
 68            connection_id,
 69            ConnectionState {
 70                user_id,
 71                worktrees: Default::default(),
 72                channels: Default::default(),
 73            },
 74        );
 75        self.connections_by_user_id
 76            .entry(user_id)
 77            .or_default()
 78            .insert(connection_id);
 79    }
 80
 81    pub fn remove_connection(
 82        &mut self,
 83        connection_id: ConnectionId,
 84    ) -> tide::Result<RemovedConnectionState> {
 85        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 86            connection
 87        } else {
 88            return Err(anyhow!("no such connection"))?;
 89        };
 90
 91        for channel_id in &connection.channels {
 92            if let Some(channel) = self.channels.get_mut(&channel_id) {
 93                channel.connection_ids.remove(&connection_id);
 94            }
 95        }
 96
 97        let user_connections = self
 98            .connections_by_user_id
 99            .get_mut(&connection.user_id)
100            .unwrap();
101        user_connections.remove(&connection_id);
102        if user_connections.is_empty() {
103            self.connections_by_user_id.remove(&connection.user_id);
104        }
105
106        let mut result = RemovedConnectionState::default();
107        for worktree_id in connection.worktrees.clone() {
108            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109                result
110                    .collaborator_ids
111                    .extend(worktree.collaborator_user_ids.iter().copied());
112                result.hosted_worktrees.insert(worktree_id, worktree);
113            } else {
114                if let Some(worktree) = self.worktrees.get(&worktree_id) {
115                    result
116                        .guest_worktree_ids
117                        .insert(worktree_id, worktree.connection_ids());
118                    result
119                        .collaborator_ids
120                        .extend(worktree.collaborator_user_ids.iter().copied());
121                }
122            }
123        }
124
125        Ok(result)
126    }
127
128    #[cfg(test)]
129    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130        self.channels.get(&id)
131    }
132
133    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134        if let Some(connection) = self.connections.get_mut(&connection_id) {
135            connection.channels.insert(channel_id);
136            self.channels
137                .entry(channel_id)
138                .or_default()
139                .connection_ids
140                .insert(connection_id);
141        }
142    }
143
144    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145        if let Some(connection) = self.connections.get_mut(&connection_id) {
146            connection.channels.remove(&channel_id);
147            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148                entry.get_mut().connection_ids.remove(&connection_id);
149                if entry.get_mut().connection_ids.is_empty() {
150                    entry.remove();
151                }
152            }
153        }
154    }
155
156    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157        Ok(self
158            .connections
159            .get(&connection_id)
160            .ok_or_else(|| anyhow!("unknown connection"))?
161            .user_id)
162    }
163
164    pub fn connection_ids_for_user<'a>(
165        &'a self,
166        user_id: UserId,
167    ) -> impl 'a + Iterator<Item = ConnectionId> {
168        self.connections_by_user_id
169            .get(&user_id)
170            .into_iter()
171            .flatten()
172            .copied()
173    }
174
175    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
176        let mut collaborators = HashMap::new();
177        for worktree_id in self
178            .visible_worktrees_by_user_id
179            .get(&user_id)
180            .unwrap_or(&HashSet::new())
181        {
182            let worktree = &self.worktrees[worktree_id];
183
184            let mut guests = HashSet::new();
185            if let Ok(share) = worktree.share() {
186                for guest_connection_id in share.guest_connection_ids.keys() {
187                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188                        guests.insert(user_id.to_proto());
189                    }
190                }
191            }
192
193            if let Ok(host_user_id) = self
194                .user_id_for_connection(worktree.host_connection_id)
195                .context("stale worktree host connection")
196            {
197                let host =
198                    collaborators
199                        .entry(host_user_id)
200                        .or_insert_with(|| proto::Collaborator {
201                            user_id: host_user_id.to_proto(),
202                            worktrees: Vec::new(),
203                        });
204                host.worktrees.push(proto::WorktreeMetadata {
205                    root_name: worktree.root_name.clone(),
206                    is_shared: worktree.share().is_ok(),
207                    participants: guests.into_iter().collect(),
208                });
209            }
210        }
211
212        collaborators.into_values().collect()
213    }
214
215    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
216        let worktree_id = self.next_worktree_id;
217        for collaborator_user_id in &worktree.collaborator_user_ids {
218            self.visible_worktrees_by_user_id
219                .entry(*collaborator_user_id)
220                .or_default()
221                .insert(worktree_id);
222        }
223        self.next_worktree_id += 1;
224        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
225            connection.worktrees.insert(worktree_id);
226        }
227        self.worktrees.insert(worktree_id, worktree);
228
229        #[cfg(test)]
230        self.check_invariants();
231
232        worktree_id
233    }
234
235    pub fn remove_worktree(
236        &mut self,
237        worktree_id: u64,
238        acting_connection_id: ConnectionId,
239    ) -> tide::Result<Worktree> {
240        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
241            if e.get().host_connection_id != acting_connection_id {
242                Err(anyhow!("not your worktree"))?;
243            }
244            e.remove()
245        } else {
246            return Err(anyhow!("no such worktree"))?;
247        };
248
249        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
250            connection.worktrees.remove(&worktree_id);
251        }
252
253        if let Some(share) = &worktree.share {
254            for connection_id in share.guest_connection_ids.keys() {
255                if let Some(connection) = self.connections.get_mut(connection_id) {
256                    connection.worktrees.remove(&worktree_id);
257                }
258            }
259        }
260
261        for collaborator_user_id in &worktree.collaborator_user_ids {
262            if let Some(visible_worktrees) = self
263                .visible_worktrees_by_user_id
264                .get_mut(&collaborator_user_id)
265            {
266                visible_worktrees.remove(&worktree_id);
267            }
268        }
269
270        #[cfg(test)]
271        self.check_invariants();
272
273        Ok(worktree)
274    }
275
276    pub fn share_worktree(
277        &mut self,
278        worktree_id: u64,
279        connection_id: ConnectionId,
280        entries: HashMap<u64, proto::Entry>,
281    ) -> Option<Vec<UserId>> {
282        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
283            if worktree.host_connection_id == connection_id {
284                worktree.share = Some(WorktreeShare {
285                    guest_connection_ids: Default::default(),
286                    active_replica_ids: Default::default(),
287                    entries,
288                });
289                return Some(worktree.collaborator_user_ids.clone());
290            }
291        }
292        None
293    }
294
295    pub fn unshare_worktree(
296        &mut self,
297        worktree_id: u64,
298        acting_connection_id: ConnectionId,
299    ) -> tide::Result<UnsharedWorktree> {
300        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
301            worktree
302        } else {
303            return Err(anyhow!("no such worktree"))?;
304        };
305
306        if worktree.host_connection_id != acting_connection_id {
307            return Err(anyhow!("not your worktree"))?;
308        }
309
310        let connection_ids = worktree.connection_ids();
311
312        if let Some(_) = worktree.share.take() {
313            for connection_id in &connection_ids {
314                if let Some(connection) = self.connections.get_mut(connection_id) {
315                    connection.worktrees.remove(&worktree_id);
316                }
317            }
318            Ok(UnsharedWorktree {
319                connection_ids,
320                collaborator_ids: worktree.collaborator_user_ids.clone(),
321            })
322        } else {
323            Err(anyhow!("worktree is not shared"))?
324        }
325    }
326
327    pub fn join_worktree(
328        &mut self,
329        connection_id: ConnectionId,
330        user_id: UserId,
331        worktree_id: u64,
332    ) -> tide::Result<JoinedWorktree> {
333        let connection = self
334            .connections
335            .get_mut(&connection_id)
336            .ok_or_else(|| anyhow!("no such connection"))?;
337        let worktree = self
338            .worktrees
339            .get_mut(&worktree_id)
340            .and_then(|worktree| {
341                if worktree.collaborator_user_ids.contains(&user_id) {
342                    Some(worktree)
343                } else {
344                    None
345                }
346            })
347            .ok_or_else(|| anyhow!("no such worktree"))?;
348
349        let share = worktree.share_mut()?;
350        connection.worktrees.insert(worktree_id);
351
352        let mut replica_id = 1;
353        while share.active_replica_ids.contains(&replica_id) {
354            replica_id += 1;
355        }
356        share.active_replica_ids.insert(replica_id);
357        share.guest_connection_ids.insert(connection_id, replica_id);
358        Ok(JoinedWorktree {
359            replica_id,
360            worktree,
361        })
362    }
363
364    pub fn leave_worktree(
365        &mut self,
366        connection_id: ConnectionId,
367        worktree_id: u64,
368    ) -> Option<LeftWorktree> {
369        let worktree = self.worktrees.get_mut(&worktree_id)?;
370        let share = worktree.share.as_mut()?;
371        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
372        share.active_replica_ids.remove(&replica_id);
373        Some(LeftWorktree {
374            connection_ids: worktree.connection_ids(),
375            collaborator_ids: worktree.collaborator_user_ids.clone(),
376        })
377    }
378
379    pub fn update_worktree(
380        &mut self,
381        connection_id: ConnectionId,
382        worktree_id: u64,
383        removed_entries: &[u64],
384        updated_entries: &[proto::Entry],
385    ) -> tide::Result<Vec<ConnectionId>> {
386        let worktree = self.write_worktree(worktree_id, connection_id)?;
387        let share = worktree.share_mut()?;
388        for entry_id in removed_entries {
389            share.entries.remove(&entry_id);
390        }
391        for entry in updated_entries {
392            share.entries.insert(entry.id, entry.clone());
393        }
394        Ok(worktree.connection_ids())
395    }
396
397    pub fn worktree_host_connection_id(
398        &self,
399        connection_id: ConnectionId,
400        worktree_id: u64,
401    ) -> tide::Result<ConnectionId> {
402        Ok(self
403            .read_worktree(worktree_id, connection_id)?
404            .host_connection_id)
405    }
406
407    pub fn worktree_guest_connection_ids(
408        &self,
409        connection_id: ConnectionId,
410        worktree_id: u64,
411    ) -> tide::Result<Vec<ConnectionId>> {
412        Ok(self
413            .read_worktree(worktree_id, connection_id)?
414            .share()?
415            .guest_connection_ids
416            .keys()
417            .copied()
418            .collect())
419    }
420
421    pub fn worktree_connection_ids(
422        &self,
423        connection_id: ConnectionId,
424        worktree_id: u64,
425    ) -> tide::Result<Vec<ConnectionId>> {
426        Ok(self
427            .read_worktree(worktree_id, connection_id)?
428            .connection_ids())
429    }
430
431    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
432        Some(self.channels.get(&channel_id)?.connection_ids())
433    }
434
435    fn read_worktree(
436        &self,
437        worktree_id: u64,
438        connection_id: ConnectionId,
439    ) -> tide::Result<&Worktree> {
440        let worktree = self
441            .worktrees
442            .get(&worktree_id)
443            .ok_or_else(|| anyhow!("worktree not found"))?;
444
445        if worktree.host_connection_id == connection_id
446            || worktree
447                .share()?
448                .guest_connection_ids
449                .contains_key(&connection_id)
450        {
451            Ok(worktree)
452        } else {
453            Err(anyhow!(
454                "{} is not a member of worktree {}",
455                connection_id,
456                worktree_id
457            ))?
458        }
459    }
460
461    fn write_worktree(
462        &mut self,
463        worktree_id: u64,
464        connection_id: ConnectionId,
465    ) -> tide::Result<&mut Worktree> {
466        let worktree = self
467            .worktrees
468            .get_mut(&worktree_id)
469            .ok_or_else(|| anyhow!("worktree not found"))?;
470
471        if worktree.host_connection_id == connection_id
472            || worktree.share.as_ref().map_or(false, |share| {
473                share.guest_connection_ids.contains_key(&connection_id)
474            })
475        {
476            Ok(worktree)
477        } else {
478            Err(anyhow!(
479                "{} is not a member of worktree {}",
480                connection_id,
481                worktree_id
482            ))?
483        }
484    }
485
486    #[cfg(test)]
487    fn check_invariants(&self) {
488        for (connection_id, connection) in &self.connections {
489            for worktree_id in &connection.worktrees {
490                let worktree = &self.worktrees.get(&worktree_id).unwrap();
491                if worktree.host_connection_id != *connection_id {
492                    assert!(worktree
493                        .share()
494                        .unwrap()
495                        .guest_connection_ids
496                        .contains_key(connection_id));
497                }
498            }
499            for channel_id in &connection.channels {
500                let channel = self.channels.get(channel_id).unwrap();
501                assert!(channel.connection_ids.contains(connection_id));
502            }
503            assert!(self
504                .connections_by_user_id
505                .get(&connection.user_id)
506                .unwrap()
507                .contains(connection_id));
508        }
509
510        for (user_id, connection_ids) in &self.connections_by_user_id {
511            for connection_id in connection_ids {
512                assert_eq!(
513                    self.connections.get(connection_id).unwrap().user_id,
514                    *user_id
515                );
516            }
517        }
518
519        for (worktree_id, worktree) in &self.worktrees {
520            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
521            assert!(host_connection.worktrees.contains(worktree_id));
522
523            for collaborator_id in &worktree.collaborator_user_ids {
524                let visible_worktree_ids = self
525                    .visible_worktrees_by_user_id
526                    .get(collaborator_id)
527                    .unwrap();
528                assert!(visible_worktree_ids.contains(worktree_id));
529            }
530
531            if let Some(share) = &worktree.share {
532                for guest_connection_id in share.guest_connection_ids.keys() {
533                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
534                    assert!(guest_connection.worktrees.contains(worktree_id));
535                }
536                assert_eq!(
537                    share.active_replica_ids.len(),
538                    share.guest_connection_ids.len(),
539                );
540                assert_eq!(
541                    share.active_replica_ids,
542                    share
543                        .guest_connection_ids
544                        .values()
545                        .copied()
546                        .collect::<HashSet<_>>(),
547                );
548            }
549        }
550
551        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
552            for worktree_id in visible_worktree_ids {
553                let worktree = self.worktrees.get(worktree_id).unwrap();
554                assert!(worktree.collaborator_user_ids.contains(user_id));
555            }
556        }
557
558        for (channel_id, channel) in &self.channels {
559            for connection_id in &channel.connection_ids {
560                let connection = self.connections.get(connection_id).unwrap();
561                assert!(connection.channels.contains(channel_id));
562            }
563        }
564    }
565}
566
567impl Worktree {
568    pub fn connection_ids(&self) -> Vec<ConnectionId> {
569        if let Some(share) = &self.share {
570            share
571                .guest_connection_ids
572                .keys()
573                .copied()
574                .chain(Some(self.host_connection_id))
575                .collect()
576        } else {
577            vec![self.host_connection_id]
578        }
579    }
580
581    pub fn share(&self) -> tide::Result<&WorktreeShare> {
582        Ok(self
583            .share
584            .as_ref()
585            .ok_or_else(|| anyhow!("worktree is not shared"))?)
586    }
587
588    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
589        Ok(self
590            .share
591            .as_mut()
592            .ok_or_else(|| anyhow!("worktree is not shared"))?)
593    }
594}
595
596impl Channel {
597    fn connection_ids(&self) -> Vec<ConnectionId> {
598        self.connection_ids.iter().copied().collect()
599    }
600}