store.rs

  1use crate::db::{ChannelId, MessageId, UserId};
  2use crate::errors::TideResultExt;
  3use anyhow::anyhow;
  4use std::collections::{hash_map, HashMap, HashSet};
  5use zrpc::{proto, ConnectionId};
  6
  7#[derive(Default)]
  8pub struct Store {
  9    connections: HashMap<ConnectionId, ConnectionState>,
 10    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 11    worktrees: HashMap<u64, Worktree>,
 12    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_worktree_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    worktrees: HashSet<u64>,
 20    channels: HashSet<ChannelId>,
 21}
 22
 23pub struct Worktree {
 24    pub host_connection_id: ConnectionId,
 25    pub collaborator_user_ids: Vec<UserId>,
 26    pub root_name: String,
 27    pub share: Option<WorktreeShare>,
 28}
 29
 30struct WorktreeShare {
 31    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
 32    pub active_replica_ids: HashSet<ReplicaId>,
 33    pub entries: HashMap<u64, proto::Entry>,
 34}
 35
 36#[derive(Default)]
 37struct Channel {
 38    connection_ids: HashSet<ConnectionId>,
 39}
 40
 41pub type ReplicaId = u16;
 42
 43#[derive(Default)]
 44pub struct RemovedConnectionState {
 45    pub hosted_worktrees: HashMap<u64, Worktree>,
 46    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
 47    pub collaborator_ids: HashSet<UserId>,
 48}
 49
 50impl Store {
 51    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 52        self.connections.insert(
 53            connection_id,
 54            ConnectionState {
 55                user_id,
 56                worktrees: Default::default(),
 57                channels: Default::default(),
 58            },
 59        );
 60        self.connections_by_user_id
 61            .entry(user_id)
 62            .or_default()
 63            .insert(connection_id);
 64    }
 65
 66    pub fn remove_connection(
 67        &mut self,
 68        connection_id: ConnectionId,
 69    ) -> tide::Result<RemovedConnectionState> {
 70        let connection = if let Some(connection) = self.connections.get(&connection_id) {
 71            connection
 72        } else {
 73            return Err(anyhow!("no such connection"))?;
 74        };
 75
 76        for channel_id in connection.channels {
 77            if let Some(channel) = self.channels.get_mut(&channel_id) {
 78                channel.connection_ids.remove(&connection_id);
 79            }
 80        }
 81
 82        let user_connections = self
 83            .connections_by_user_id
 84            .get_mut(&connection.user_id)
 85            .unwrap();
 86        user_connections.remove(&connection_id);
 87        if user_connections.is_empty() {
 88            self.connections_by_user_id.remove(&connection.user_id);
 89        }
 90
 91        let mut result = RemovedConnectionState::default();
 92        for worktree_id in connection.worktrees {
 93            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
 94                result.hosted_worktrees.insert(worktree_id, worktree);
 95                result
 96                    .collaborator_ids
 97                    .extend(worktree.collaborator_user_ids.iter().copied());
 98            } else {
 99                if let Some(worktree) = self.worktrees.get(&worktree_id) {
100                    result
101                        .guest_worktree_ids
102                        .insert(worktree_id, worktree.connection_ids());
103                    result
104                        .collaborator_ids
105                        .extend(worktree.collaborator_user_ids.iter().copied());
106                }
107            }
108        }
109
110        Ok(result)
111    }
112
113    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
114        if let Some(connection) = self.connections.get_mut(&connection_id) {
115            connection.channels.insert(channel_id);
116            self.channels
117                .entry(channel_id)
118                .or_default()
119                .connection_ids
120                .insert(connection_id);
121        }
122    }
123
124    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
125        if let Some(connection) = self.connections.get_mut(&connection_id) {
126            connection.channels.remove(&channel_id);
127            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
128                entry.get_mut().connection_ids.remove(&connection_id);
129                if entry.get_mut().connection_ids.is_empty() {
130                    entry.remove();
131                }
132            }
133        }
134    }
135
136    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
137        Ok(self
138            .connections
139            .get(&connection_id)
140            .ok_or_else(|| anyhow!("unknown connection"))?
141            .user_id)
142    }
143
144    pub fn connection_ids_for_user<'a>(
145        &'a self,
146        user_id: UserId,
147    ) -> impl 'a + Iterator<Item = ConnectionId> {
148        self.connections_by_user_id
149            .get(&user_id)
150            .into_iter()
151            .flatten()
152            .copied()
153    }
154
155    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
156        let mut collaborators = HashMap::new();
157        for worktree_id in self
158            .visible_worktrees_by_user_id
159            .get(&user_id)
160            .unwrap_or(&HashSet::new())
161        {
162            let worktree = &self.worktrees[worktree_id];
163
164            let mut guests = HashSet::new();
165            if let Ok(share) = worktree.share() {
166                for guest_connection_id in share.guest_connection_ids.keys() {
167                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
168                        guests.insert(user_id.to_proto());
169                    }
170                }
171            }
172
173            if let Ok(host_user_id) = self
174                .user_id_for_connection(worktree.host_connection_id)
175                .context("stale worktree host connection")
176            {
177                let host =
178                    collaborators
179                        .entry(host_user_id)
180                        .or_insert_with(|| proto::Collaborator {
181                            user_id: host_user_id.to_proto(),
182                            worktrees: Vec::new(),
183                        });
184                host.worktrees.push(proto::WorktreeMetadata {
185                    root_name: worktree.root_name.clone(),
186                    is_shared: worktree.share().is_ok(),
187                    participants: guests.into_iter().collect(),
188                });
189            }
190        }
191
192        collaborators.into_values().collect()
193    }
194
195    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
196        let worktree_id = self.next_worktree_id;
197        for collaborator_user_id in &worktree.collaborator_user_ids {
198            self.visible_worktrees_by_user_id
199                .entry(*collaborator_user_id)
200                .or_default()
201                .insert(worktree_id);
202        }
203        self.next_worktree_id += 1;
204        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
205            connection.worktrees.insert(worktree_id);
206        }
207        self.worktrees.insert(worktree_id, worktree);
208
209        #[cfg(test)]
210        self.check_invariants();
211
212        worktree_id
213    }
214
215    pub fn remove_worktree(
216        &mut self,
217        worktree_id: u64,
218        acting_connection_id: ConnectionId,
219    ) -> tide::Result<Worktree> {
220        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
221            if e.get().host_connection_id != acting_connection_id {
222                Err(anyhow!("not your worktree"))?;
223            }
224            e.remove()
225        } else {
226            return Err(anyhow!("no such worktree"))?;
227        };
228
229        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
230            connection.worktrees.remove(&worktree_id);
231        }
232
233        if let Some(share) = worktree.share {
234            for connection_id in share.guest_connection_ids.keys() {
235                if let Some(connection) = self.connections.get_mut(connection_id) {
236                    connection.worktrees.remove(&worktree_id);
237                }
238            }
239        }
240
241        for collaborator_user_id in worktree.collaborator_user_ids {
242            if let Some(visible_worktrees) = self
243                .visible_worktrees_by_user_id
244                .get_mut(&collaborator_user_id)
245            {
246                visible_worktrees.remove(&worktree_id);
247            }
248        }
249
250        #[cfg(test)]
251        self.check_invariants();
252
253        Ok(worktree)
254    }
255
256    pub fn share_worktree(
257        &mut self,
258        worktree_id: u64,
259        connection_id: ConnectionId,
260        entries: HashMap<u64, proto::Entry>,
261    ) -> Option<Vec<UserId>> {
262        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
263            if worktree.host_connection_id == connection_id {
264                worktree.share = Some(WorktreeShare {
265                    guest_connection_ids: Default::default(),
266                    active_replica_ids: Default::default(),
267                    entries,
268                });
269                return Some(worktree.collaborator_user_ids.clone());
270            }
271        }
272        None
273    }
274
275    pub fn unshare_worktree(
276        &mut self,
277        worktree_id: u64,
278        acting_connection_id: ConnectionId,
279    ) -> tide::Result<(Vec<ConnectionId>, Vec<UserId>)> {
280        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
281            worktree
282        } else {
283            return Err(anyhow!("no such worktree"))?;
284        };
285
286        if worktree.host_connection_id != acting_connection_id {
287            return Err(anyhow!("not your worktree"))?;
288        }
289
290        let connection_ids = worktree.connection_ids();
291
292        if let Some(share) = worktree.share.take() {
293            for connection_id in &connection_ids {
294                if let Some(connection) = self.connections.get_mut(connection_id) {
295                    connection.worktrees.remove(&worktree_id);
296                }
297            }
298            Ok((connection_ids, worktree.collaborator_user_ids.clone()))
299        } else {
300            Err(anyhow!("worktree is not shared"))?
301        }
302    }
303
304    pub fn join_worktree(
305        &mut self,
306        connection_id: ConnectionId,
307        user_id: UserId,
308        worktree_id: u64,
309    ) -> tide::Result<(ReplicaId, &Worktree)> {
310        let connection = self
311            .connections
312            .get_mut(&connection_id)
313            .ok_or_else(|| anyhow!("no such connection"))?;
314        let worktree = self
315            .worktrees
316            .get_mut(&worktree_id)
317            .and_then(|worktree| {
318                if worktree.collaborator_user_ids.contains(&user_id) {
319                    Some(worktree)
320                } else {
321                    None
322                }
323            })
324            .ok_or_else(|| anyhow!("no such worktree"))?;
325
326        let share = worktree.share_mut()?;
327        connection.worktrees.insert(worktree_id);
328
329        let mut replica_id = 1;
330        while share.active_replica_ids.contains(&replica_id) {
331            replica_id += 1;
332        }
333        share.active_replica_ids.insert(replica_id);
334        share.guest_connection_ids.insert(connection_id, replica_id);
335        return Ok((replica_id, worktree));
336    }
337
338    pub fn leave_worktree(
339        &mut self,
340        connection_id: ConnectionId,
341        worktree_id: u64,
342    ) -> Option<(Vec<ConnectionId>, Vec<UserId>)> {
343        let worktree = self.worktrees.get_mut(&worktree_id)?;
344        let share = worktree.share.as_mut()?;
345        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
346        share.active_replica_ids.remove(&replica_id);
347        Some((
348            worktree.connection_ids(),
349            worktree.collaborator_user_ids.clone(),
350        ))
351    }
352
353    pub fn update_worktree(
354        &mut self,
355        connection_id: ConnectionId,
356        worktree_id: u64,
357        removed_entries: &[u64],
358        updated_entries: &[proto::Entry],
359    ) -> tide::Result<Vec<ConnectionId>> {
360        let worktree = self.write_worktree(worktree_id, connection_id)?;
361        let share = worktree.share_mut()?;
362        for entry_id in removed_entries {
363            share.entries.remove(&entry_id);
364        }
365        for entry in updated_entries {
366            share.entries.insert(entry.id, entry.clone());
367        }
368        Ok(worktree.connection_ids())
369    }
370
371    pub fn worktree_host_connection_id(
372        &self,
373        connection_id: ConnectionId,
374        worktree_id: u64,
375    ) -> tide::Result<ConnectionId> {
376        Ok(self
377            .read_worktree(worktree_id, connection_id)?
378            .host_connection_id)
379    }
380
381    pub fn worktree_guest_connection_ids(
382        &self,
383        connection_id: ConnectionId,
384        worktree_id: u64,
385    ) -> tide::Result<Vec<ConnectionId>> {
386        Ok(self
387            .read_worktree(worktree_id, connection_id)?
388            .share()?
389            .guest_connection_ids
390            .keys()
391            .copied()
392            .collect())
393    }
394
395    pub fn worktree_connection_ids(
396        &self,
397        connection_id: ConnectionId,
398        worktree_id: u64,
399    ) -> tide::Result<Vec<ConnectionId>> {
400        Ok(self
401            .read_worktree(worktree_id, connection_id)?
402            .connection_ids())
403    }
404
405    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
406        Some(self.channels.get(&channel_id)?.connection_ids())
407    }
408
409    fn read_worktree(
410        &self,
411        worktree_id: u64,
412        connection_id: ConnectionId,
413    ) -> tide::Result<&Worktree> {
414        let worktree = self
415            .worktrees
416            .get(&worktree_id)
417            .ok_or_else(|| anyhow!("worktree not found"))?;
418
419        if worktree.host_connection_id == connection_id
420            || worktree
421                .share()?
422                .guest_connection_ids
423                .contains_key(&connection_id)
424        {
425            Ok(worktree)
426        } else {
427            Err(anyhow!(
428                "{} is not a member of worktree {}",
429                connection_id,
430                worktree_id
431            ))?
432        }
433    }
434
435    fn write_worktree(
436        &mut self,
437        worktree_id: u64,
438        connection_id: ConnectionId,
439    ) -> tide::Result<&mut Worktree> {
440        let worktree = self
441            .worktrees
442            .get_mut(&worktree_id)
443            .ok_or_else(|| anyhow!("worktree not found"))?;
444
445        if worktree.host_connection_id == connection_id
446            || worktree.share.as_ref().map_or(false, |share| {
447                share.guest_connection_ids.contains_key(&connection_id)
448            })
449        {
450            Ok(worktree)
451        } else {
452            Err(anyhow!(
453                "{} is not a member of worktree {}",
454                connection_id,
455                worktree_id
456            ))?
457        }
458    }
459
460    #[cfg(test)]
461    fn check_invariants(&self) {
462        for (connection_id, connection) in &self.connections {
463            for worktree_id in &connection.worktrees {
464                let worktree = &self.worktrees.get(&worktree_id).unwrap();
465                if worktree.host_connection_id != *connection_id {
466                    assert!(worktree
467                        .share()
468                        .unwrap()
469                        .guest_connection_ids
470                        .contains_key(connection_id));
471                }
472            }
473            for channel_id in &connection.channels {
474                let channel = self.channels.get(channel_id).unwrap();
475                assert!(channel.connection_ids.contains(connection_id));
476            }
477            assert!(self
478                .connections_by_user_id
479                .get(&connection.user_id)
480                .unwrap()
481                .contains(connection_id));
482        }
483
484        for (user_id, connection_ids) in &self.connections_by_user_id {
485            for connection_id in connection_ids {
486                assert_eq!(
487                    self.connections.get(connection_id).unwrap().user_id,
488                    *user_id
489                );
490            }
491        }
492
493        for (worktree_id, worktree) in &self.worktrees {
494            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
495            assert!(host_connection.worktrees.contains(worktree_id));
496
497            for collaborator_id in &worktree.collaborator_user_ids {
498                let visible_worktree_ids = self
499                    .visible_worktrees_by_user_id
500                    .get(collaborator_id)
501                    .unwrap();
502                assert!(visible_worktree_ids.contains(worktree_id));
503            }
504
505            if let Some(share) = &worktree.share {
506                for guest_connection_id in share.guest_connection_ids.keys() {
507                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
508                    assert!(guest_connection.worktrees.contains(worktree_id));
509                }
510                assert_eq!(
511                    share.active_replica_ids.len(),
512                    share.guest_connection_ids.len(),
513                );
514                assert_eq!(
515                    share.active_replica_ids,
516                    share
517                        .guest_connection_ids
518                        .values()
519                        .copied()
520                        .collect::<HashSet<_>>(),
521                );
522            }
523        }
524
525        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
526            for worktree_id in visible_worktree_ids {
527                let worktree = self.worktrees.get(worktree_id).unwrap();
528                assert!(worktree.collaborator_user_ids.contains(user_id));
529            }
530        }
531
532        for (channel_id, channel) in &self.channels {
533            for connection_id in &channel.connection_ids {
534                let connection = self.connections.get(connection_id).unwrap();
535                assert!(connection.channels.contains(channel_id));
536            }
537        }
538    }
539}
540
541impl Worktree {
542    pub fn connection_ids(&self) -> Vec<ConnectionId> {
543        if let Some(share) = &self.share {
544            share
545                .guest_connection_ids
546                .keys()
547                .copied()
548                .chain(Some(self.host_connection_id))
549                .collect()
550        } else {
551            vec![self.host_connection_id]
552        }
553    }
554
555    pub fn share(&self) -> tide::Result<&WorktreeShare> {
556        Ok(self
557            .share
558            .as_ref()
559            .ok_or_else(|| anyhow!("worktree is not shared"))?)
560    }
561
562    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
563        Ok(self
564            .share
565            .as_mut()
566            .ok_or_else(|| anyhow!("worktree is not shared"))?)
567    }
568}
569
570impl Channel {
571    fn connection_ids(&self) -> Vec<ConnectionId> {
572        self.connection_ids.iter().copied().collect()
573    }
574}