store.rs

  1use crate::db::{ChannelId, UserId};
  2use crate::errors::TideResultExt;
  3use anyhow::anyhow;
  4use std::collections::{hash_map, HashMap, HashSet};
  5use zrpc::{proto, ConnectionId};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    worktrees: HashMap<u64, Worktree>,
 12    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_worktree_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    worktrees: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Worktree {
 24    pub host_connection_id: ConnectionId,
 25    pub collaborator_user_ids: Vec<UserId>,
 26    pub root_name: String,
 27    pub share: Option<WorktreeShare>,
 28}
 29
 30pub struct WorktreeShare {
 31    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 32    pub active_replica_ids: HashSet<ReplicaId>,
 33    pub entries: HashMap<u64, proto::Entry>,
 34}
 35
 36#[derive(Default)]
 37pub struct Channel {
 38    pub connection_ids: HashSet<ConnectionId>,
 39}
 40
 41pub type ReplicaId = u16;
 42
 43#[derive(Default)]
 44pub struct RemovedConnectionState {
 45    pub hosted_worktrees: HashMap<u64, Worktree>,
 46    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 47    pub collaborator_ids: HashSet<UserId>,
 48}
 49
 50pub struct JoinedWorktree<'a> {
 51    pub replica_id: ReplicaId,
 52    pub worktree: &'a Worktree,
 53}
 54
 55pub struct WorktreeMetadata {
 56    pub connection_ids: Vec<ConnectionId>,
 57    pub collaborator_ids: Vec<UserId>,
 58}
 59
 60impl Store {
 61    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 62        self.connections.insert(
 63            connection_id,
 64            ConnectionState {
 65                user_id,
 66                worktrees: Default::default(),
 67                channels: Default::default(),
 68            },
 69        );
 70        self.connections_by_user_id
 71            .entry(user_id)
 72            .or_default()
 73            .insert(connection_id);
 74    }
 75
 76    pub fn remove_connection(
 77        &mut self,
 78        connection_id: ConnectionId,
 79    ) -> tide::Result<RemovedConnectionState> {
 80        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 81            connection
 82        } else {
 83            return Err(anyhow!("no such connection"))?;
 84        };
 85
 86        for channel_id in &connection.channels {
 87            if let Some(channel) = self.channels.get_mut(&channel_id) {
 88                channel.connection_ids.remove(&connection_id);
 89            }
 90        }
 91
 92        let user_connections = self
 93            .connections_by_user_id
 94            .get_mut(&connection.user_id)
 95            .unwrap();
 96        user_connections.remove(&connection_id);
 97        if user_connections.is_empty() {
 98            self.connections_by_user_id.remove(&connection.user_id);
 99        }
100
101        let mut result = RemovedConnectionState::default();
102        for worktree_id in connection.worktrees.clone() {
103            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
104                result
105                    .collaborator_ids
106                    .extend(worktree.collaborator_user_ids.iter().copied());
107                result.hosted_worktrees.insert(worktree_id, worktree);
108            } else {
109                if let Some(worktree) = self.worktrees.get(&worktree_id) {
110                    result
111                        .guest_worktree_ids
112                        .insert(worktree_id, worktree.connection_ids());
113                    result
114                        .collaborator_ids
115                        .extend(worktree.collaborator_user_ids.iter().copied());
116                }
117            }
118        }
119
120        Ok(result)
121    }
122
123    #[cfg(test)]
124    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
125        self.channels.get(&id)
126    }
127
128    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
129        if let Some(connection) = self.connections.get_mut(&connection_id) {
130            connection.channels.insert(channel_id);
131            self.channels
132                .entry(channel_id)
133                .or_default()
134                .connection_ids
135                .insert(connection_id);
136        }
137    }
138
139    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
140        if let Some(connection) = self.connections.get_mut(&connection_id) {
141            connection.channels.remove(&channel_id);
142            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
143                entry.get_mut().connection_ids.remove(&connection_id);
144                if entry.get_mut().connection_ids.is_empty() {
145                    entry.remove();
146                }
147            }
148        }
149    }
150
151    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
152        Ok(self
153            .connections
154            .get(&connection_id)
155            .ok_or_else(|| anyhow!("unknown connection"))?
156            .user_id)
157    }
158
159    pub fn connection_ids_for_user<'a>(
160        &'a self,
161        user_id: UserId,
162    ) -> impl 'a + Iterator<Item = ConnectionId> {
163        self.connections_by_user_id
164            .get(&user_id)
165            .into_iter()
166            .flatten()
167            .copied()
168    }
169
170    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
171        let mut collaborators = HashMap::new();
172        for worktree_id in self
173            .visible_worktrees_by_user_id
174            .get(&user_id)
175            .unwrap_or(&HashSet::new())
176        {
177            let worktree = &self.worktrees[worktree_id];
178
179            let mut guests = HashSet::new();
180            if let Ok(share) = worktree.share() {
181                for guest_connection_id in share.guest_connection_ids.keys() {
182                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
183                        guests.insert(user_id.to_proto());
184                    }
185                }
186            }
187
188            if let Ok(host_user_id) = self
189                .user_id_for_connection(worktree.host_connection_id)
190                .context("stale worktree host connection")
191            {
192                let host =
193                    collaborators
194                        .entry(host_user_id)
195                        .or_insert_with(|| proto::Collaborator {
196                            user_id: host_user_id.to_proto(),
197                            worktrees: Vec::new(),
198                        });
199                host.worktrees.push(proto::WorktreeMetadata {
200                    root_name: worktree.root_name.clone(),
201                    is_shared: worktree.share().is_ok(),
202                    participants: guests.into_iter().collect(),
203                });
204            }
205        }
206
207        collaborators.into_values().collect()
208    }
209
210    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
211        let worktree_id = self.next_worktree_id;
212        for collaborator_user_id in &worktree.collaborator_user_ids {
213            self.visible_worktrees_by_user_id
214                .entry(*collaborator_user_id)
215                .or_default()
216                .insert(worktree_id);
217        }
218        self.next_worktree_id += 1;
219        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
220            connection.worktrees.insert(worktree_id);
221        }
222        self.worktrees.insert(worktree_id, worktree);
223
224        #[cfg(test)]
225        self.check_invariants();
226
227        worktree_id
228    }
229
230    pub fn remove_worktree(
231        &mut self,
232        worktree_id: u64,
233        acting_connection_id: ConnectionId,
234    ) -> tide::Result<Worktree> {
235        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
236            if e.get().host_connection_id != acting_connection_id {
237                Err(anyhow!("not your worktree"))?;
238            }
239            e.remove()
240        } else {
241            return Err(anyhow!("no such worktree"))?;
242        };
243
244        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
245            connection.worktrees.remove(&worktree_id);
246        }
247
248        if let Some(share) = &worktree.share {
249            for connection_id in share.guest_connection_ids.keys() {
250                if let Some(connection) = self.connections.get_mut(connection_id) {
251                    connection.worktrees.remove(&worktree_id);
252                }
253            }
254        }
255
256        for collaborator_user_id in &worktree.collaborator_user_ids {
257            if let Some(visible_worktrees) = self
258                .visible_worktrees_by_user_id
259                .get_mut(&collaborator_user_id)
260            {
261                visible_worktrees.remove(&worktree_id);
262            }
263        }
264
265        #[cfg(test)]
266        self.check_invariants();
267
268        Ok(worktree)
269    }
270
271    pub fn share_worktree(
272        &mut self,
273        worktree_id: u64,
274        connection_id: ConnectionId,
275        entries: HashMap<u64, proto::Entry>,
276    ) -> Option<Vec<UserId>> {
277        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
278            if worktree.host_connection_id == connection_id {
279                worktree.share = Some(WorktreeShare {
280                    guest_connection_ids: Default::default(),
281                    active_replica_ids: Default::default(),
282                    entries,
283                });
284                return Some(worktree.collaborator_user_ids.clone());
285            }
286        }
287        None
288    }
289
290    pub fn unshare_worktree(
291        &mut self,
292        worktree_id: u64,
293        acting_connection_id: ConnectionId,
294    ) -> tide::Result<WorktreeMetadata> {
295        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
296            worktree
297        } else {
298            return Err(anyhow!("no such worktree"))?;
299        };
300
301        if worktree.host_connection_id != acting_connection_id {
302            return Err(anyhow!("not your worktree"))?;
303        }
304
305        let connection_ids = worktree.connection_ids();
306
307        if let Some(_) = worktree.share.take() {
308            for connection_id in &connection_ids {
309                if let Some(connection) = self.connections.get_mut(connection_id) {
310                    connection.worktrees.remove(&worktree_id);
311                }
312            }
313            Ok(WorktreeMetadata {
314                connection_ids,
315                collaborator_ids: worktree.collaborator_user_ids.clone(),
316            })
317        } else {
318            Err(anyhow!("worktree is not shared"))?
319        }
320    }
321
322    pub fn join_worktree(
323        &mut self,
324        connection_id: ConnectionId,
325        user_id: UserId,
326        worktree_id: u64,
327    ) -> tide::Result<JoinedWorktree> {
328        let connection = self
329            .connections
330            .get_mut(&connection_id)
331            .ok_or_else(|| anyhow!("no such connection"))?;
332        let worktree = self
333            .worktrees
334            .get_mut(&worktree_id)
335            .and_then(|worktree| {
336                if worktree.collaborator_user_ids.contains(&user_id) {
337                    Some(worktree)
338                } else {
339                    None
340                }
341            })
342            .ok_or_else(|| anyhow!("no such worktree"))?;
343
344        let share = worktree.share_mut()?;
345        connection.worktrees.insert(worktree_id);
346
347        let mut replica_id = 1;
348        while share.active_replica_ids.contains(&replica_id) {
349            replica_id += 1;
350        }
351        share.active_replica_ids.insert(replica_id);
352        share.guest_connection_ids.insert(connection_id, replica_id);
353        Ok(JoinedWorktree {
354            replica_id,
355            worktree,
356        })
357    }
358
359    pub fn leave_worktree(
360        &mut self,
361        connection_id: ConnectionId,
362        worktree_id: u64,
363    ) -> Option<WorktreeMetadata> {
364        let worktree = self.worktrees.get_mut(&worktree_id)?;
365        let share = worktree.share.as_mut()?;
366        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
367        share.active_replica_ids.remove(&replica_id);
368        Some(WorktreeMetadata {
369            connection_ids: worktree.connection_ids(),
370            collaborator_ids: worktree.collaborator_user_ids.clone(),
371        })
372    }
373
374    pub fn update_worktree(
375        &mut self,
376        connection_id: ConnectionId,
377        worktree_id: u64,
378        removed_entries: &[u64],
379        updated_entries: &[proto::Entry],
380    ) -> tide::Result<Vec<ConnectionId>> {
381        let worktree = self.write_worktree(worktree_id, connection_id)?;
382        let share = worktree.share_mut()?;
383        for entry_id in removed_entries {
384            share.entries.remove(&entry_id);
385        }
386        for entry in updated_entries {
387            share.entries.insert(entry.id, entry.clone());
388        }
389        Ok(worktree.connection_ids())
390    }
391
392    pub fn worktree_host_connection_id(
393        &self,
394        connection_id: ConnectionId,
395        worktree_id: u64,
396    ) -> tide::Result<ConnectionId> {
397        Ok(self
398            .read_worktree(worktree_id, connection_id)?
399            .host_connection_id)
400    }
401
402    pub fn worktree_guest_connection_ids(
403        &self,
404        connection_id: ConnectionId,
405        worktree_id: u64,
406    ) -> tide::Result<Vec<ConnectionId>> {
407        Ok(self
408            .read_worktree(worktree_id, connection_id)?
409            .share()?
410            .guest_connection_ids
411            .keys()
412            .copied()
413            .collect())
414    }
415
416    pub fn worktree_connection_ids(
417        &self,
418        connection_id: ConnectionId,
419        worktree_id: u64,
420    ) -> tide::Result<Vec<ConnectionId>> {
421        Ok(self
422            .read_worktree(worktree_id, connection_id)?
423            .connection_ids())
424    }
425
426    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
427        Some(self.channels.get(&channel_id)?.connection_ids())
428    }
429
430    fn read_worktree(
431        &self,
432        worktree_id: u64,
433        connection_id: ConnectionId,
434    ) -> tide::Result<&Worktree> {
435        let worktree = self
436            .worktrees
437            .get(&worktree_id)
438            .ok_or_else(|| anyhow!("worktree not found"))?;
439
440        if worktree.host_connection_id == connection_id
441            || worktree
442                .share()?
443                .guest_connection_ids
444                .contains_key(&connection_id)
445        {
446            Ok(worktree)
447        } else {
448            Err(anyhow!(
449                "{} is not a member of worktree {}",
450                connection_id,
451                worktree_id
452            ))?
453        }
454    }
455
456    fn write_worktree(
457        &mut self,
458        worktree_id: u64,
459        connection_id: ConnectionId,
460    ) -> tide::Result<&mut Worktree> {
461        let worktree = self
462            .worktrees
463            .get_mut(&worktree_id)
464            .ok_or_else(|| anyhow!("worktree not found"))?;
465
466        if worktree.host_connection_id == connection_id
467            || worktree.share.as_ref().map_or(false, |share| {
468                share.guest_connection_ids.contains_key(&connection_id)
469            })
470        {
471            Ok(worktree)
472        } else {
473            Err(anyhow!(
474                "{} is not a member of worktree {}",
475                connection_id,
476                worktree_id
477            ))?
478        }
479    }
480
481    #[cfg(test)]
482    fn check_invariants(&self) {
483        for (connection_id, connection) in &self.connections {
484            for worktree_id in &connection.worktrees {
485                let worktree = &self.worktrees.get(&worktree_id).unwrap();
486                if worktree.host_connection_id != *connection_id {
487                    assert!(worktree
488                        .share()
489                        .unwrap()
490                        .guest_connection_ids
491                        .contains_key(connection_id));
492                }
493            }
494            for channel_id in &connection.channels {
495                let channel = self.channels.get(channel_id).unwrap();
496                assert!(channel.connection_ids.contains(connection_id));
497            }
498            assert!(self
499                .connections_by_user_id
500                .get(&connection.user_id)
501                .unwrap()
502                .contains(connection_id));
503        }
504
505        for (user_id, connection_ids) in &self.connections_by_user_id {
506            for connection_id in connection_ids {
507                assert_eq!(
508                    self.connections.get(connection_id).unwrap().user_id,
509                    *user_id
510                );
511            }
512        }
513
514        for (worktree_id, worktree) in &self.worktrees {
515            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
516            assert!(host_connection.worktrees.contains(worktree_id));
517
518            for collaborator_id in &worktree.collaborator_user_ids {
519                let visible_worktree_ids = self
520                    .visible_worktrees_by_user_id
521                    .get(collaborator_id)
522                    .unwrap();
523                assert!(visible_worktree_ids.contains(worktree_id));
524            }
525
526            if let Some(share) = &worktree.share {
527                for guest_connection_id in share.guest_connection_ids.keys() {
528                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
529                    assert!(guest_connection.worktrees.contains(worktree_id));
530                }
531                assert_eq!(
532                    share.active_replica_ids.len(),
533                    share.guest_connection_ids.len(),
534                );
535                assert_eq!(
536                    share.active_replica_ids,
537                    share
538                        .guest_connection_ids
539                        .values()
540                        .copied()
541                        .collect::<HashSet<_>>(),
542                );
543            }
544        }
545
546        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
547            for worktree_id in visible_worktree_ids {
548                let worktree = self.worktrees.get(worktree_id).unwrap();
549                assert!(worktree.collaborator_user_ids.contains(user_id));
550            }
551        }
552
553        for (channel_id, channel) in &self.channels {
554            for connection_id in &channel.connection_ids {
555                let connection = self.connections.get(connection_id).unwrap();
556                assert!(connection.channels.contains(channel_id));
557            }
558        }
559    }
560}
561
562impl Worktree {
563    pub fn connection_ids(&self) -> Vec<ConnectionId> {
564        if let Some(share) = &self.share {
565            share
566                .guest_connection_ids
567                .keys()
568                .copied()
569                .chain(Some(self.host_connection_id))
570                .collect()
571        } else {
572            vec![self.host_connection_id]
573        }
574    }
575
576    pub fn share(&self) -> tide::Result<&WorktreeShare> {
577        Ok(self
578            .share
579            .as_ref()
580            .ok_or_else(|| anyhow!("worktree is not shared"))?)
581    }
582
583    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
584        Ok(self
585            .share
586            .as_mut()
587            .ok_or_else(|| anyhow!("worktree is not shared"))?)
588    }
589}
590
591impl Channel {
592    fn connection_ids(&self) -> Vec<ConnectionId> {
593        self.connection_ids.iter().copied().collect()
594    }
595}