store.rs

  1use crate::db::{self, ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId, Receipt};
  5use std::{collections::hash_map, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    requested_projects: HashSet<u64>,
 21    channels: HashSet<ChannelId>,
 22}
 23
 24pub struct Project {
 25    pub host_connection_id: ConnectionId,
 26    pub host_user_id: UserId,
 27    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 28    pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
 29    pub active_replica_ids: HashSet<ReplicaId>,
 30    pub worktrees: HashMap<u64, Worktree>,
 31    pub language_servers: Vec<proto::LanguageServer>,
 32}
 33
 34#[derive(Default)]
 35pub struct Worktree {
 36    pub root_name: String,
 37    pub visible: bool,
 38    pub entries: HashMap<u64, proto::Entry>,
 39    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 40    pub scan_id: u64,
 41}
 42
 43#[derive(Default)]
 44pub struct Channel {
 45    pub connection_ids: HashSet<ConnectionId>,
 46}
 47
 48pub type ReplicaId = u16;
 49
 50#[derive(Default)]
 51pub struct RemovedConnectionState {
 52    pub user_id: UserId,
 53    pub hosted_projects: HashMap<u64, Project>,
 54    pub guest_project_ids: HashSet<u64>,
 55    pub contact_ids: HashSet<UserId>,
 56}
 57
 58pub struct LeftProject {
 59    pub connection_ids: Vec<ConnectionId>,
 60    pub host_user_id: UserId,
 61    pub host_connection_id: ConnectionId,
 62}
 63
 64#[derive(Copy, Clone)]
 65pub struct Metrics {
 66    pub connections: usize,
 67    pub registered_projects: usize,
 68    pub shared_projects: usize,
 69}
 70
 71impl Store {
 72    pub fn metrics(&self) -> Metrics {
 73        let connections = self.connections.len();
 74        let mut registered_projects = 0;
 75        let mut shared_projects = 0;
 76        for project in self.projects.values() {
 77            registered_projects += 1;
 78            if !project.guests.is_empty() {
 79                shared_projects += 1;
 80            }
 81        }
 82
 83        Metrics {
 84            connections,
 85            registered_projects,
 86            shared_projects,
 87        }
 88    }
 89
 90    #[instrument(skip(self))]
 91    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 92        self.connections.insert(
 93            connection_id,
 94            ConnectionState {
 95                user_id,
 96                projects: Default::default(),
 97                requested_projects: Default::default(),
 98                channels: Default::default(),
 99            },
100        );
101        self.connections_by_user_id
102            .entry(user_id)
103            .or_default()
104            .insert(connection_id);
105    }
106
107    #[instrument(skip(self))]
108    pub fn remove_connection(
109        &mut self,
110        connection_id: ConnectionId,
111    ) -> Result<RemovedConnectionState> {
112        let connection = if let Some(connection) = self.connections.remove(&connection_id) {
113            connection
114        } else {
115            return Err(anyhow!("no such connection"))?;
116        };
117
118        for channel_id in &connection.channels {
119            if let Some(channel) = self.channels.get_mut(&channel_id) {
120                channel.connection_ids.remove(&connection_id);
121            }
122        }
123
124        let user_connections = self
125            .connections_by_user_id
126            .get_mut(&connection.user_id)
127            .unwrap();
128        user_connections.remove(&connection_id);
129        if user_connections.is_empty() {
130            self.connections_by_user_id.remove(&connection.user_id);
131        }
132
133        let mut result = RemovedConnectionState::default();
134        result.user_id = connection.user_id;
135        for project_id in connection.projects.clone() {
136            if let Ok(project) = self.unregister_project(project_id, connection_id) {
137                result.hosted_projects.insert(project_id, project);
138            } else if self.leave_project(connection_id, project_id).is_ok() {
139                result.guest_project_ids.insert(project_id);
140            }
141        }
142
143        Ok(result)
144    }
145
146    #[cfg(test)]
147    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
148        self.channels.get(&id)
149    }
150
151    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
152        if let Some(connection) = self.connections.get_mut(&connection_id) {
153            connection.channels.insert(channel_id);
154            self.channels
155                .entry(channel_id)
156                .or_default()
157                .connection_ids
158                .insert(connection_id);
159        }
160    }
161
162    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
163        if let Some(connection) = self.connections.get_mut(&connection_id) {
164            connection.channels.remove(&channel_id);
165            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
166                entry.get_mut().connection_ids.remove(&connection_id);
167                if entry.get_mut().connection_ids.is_empty() {
168                    entry.remove();
169                }
170            }
171        }
172    }
173
174    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
175        Ok(self
176            .connections
177            .get(&connection_id)
178            .ok_or_else(|| anyhow!("unknown connection"))?
179            .user_id)
180    }
181
182    pub fn connection_ids_for_user<'a>(
183        &'a self,
184        user_id: UserId,
185    ) -> impl 'a + Iterator<Item = ConnectionId> {
186        self.connections_by_user_id
187            .get(&user_id)
188            .into_iter()
189            .flatten()
190            .copied()
191    }
192
193    pub fn is_user_online(&self, user_id: UserId) -> bool {
194        !self
195            .connections_by_user_id
196            .get(&user_id)
197            .unwrap_or(&Default::default())
198            .is_empty()
199    }
200
201    pub fn build_initial_contacts_update(
202        &self,
203        contacts: Vec<db::Contact>,
204    ) -> proto::UpdateContacts {
205        let mut update = proto::UpdateContacts::default();
206
207        for contact in contacts {
208            match contact {
209                db::Contact::Accepted {
210                    user_id,
211                    should_notify,
212                } => {
213                    update
214                        .contacts
215                        .push(self.contact_for_user(user_id, should_notify));
216                }
217                db::Contact::Outgoing { user_id } => {
218                    update.outgoing_requests.push(user_id.to_proto())
219                }
220                db::Contact::Incoming {
221                    user_id,
222                    should_notify,
223                } => update
224                    .incoming_requests
225                    .push(proto::IncomingContactRequest {
226                        requester_id: user_id.to_proto(),
227                        should_notify,
228                    }),
229            }
230        }
231
232        update
233    }
234
235    pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
236        proto::Contact {
237            user_id: user_id.to_proto(),
238            projects: self.project_metadata_for_user(user_id),
239            online: self.is_user_online(user_id),
240            should_notify,
241        }
242    }
243
244    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
245        let connection_ids = self.connections_by_user_id.get(&user_id);
246        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
247            connection_ids
248                .iter()
249                .filter_map(|connection_id| self.connections.get(connection_id))
250                .flat_map(|connection| connection.projects.iter().copied())
251        });
252
253        let mut metadata = Vec::new();
254        for project_id in project_ids {
255            if let Some(project) = self.projects.get(&project_id) {
256                if project.host_user_id == user_id {
257                    metadata.push(proto::ProjectMetadata {
258                        id: project_id,
259                        worktree_root_names: project
260                            .worktrees
261                            .values()
262                            .map(|worktree| worktree.root_name.clone())
263                            .collect(),
264                        guests: project
265                            .guests
266                            .values()
267                            .map(|(_, user_id)| user_id.to_proto())
268                            .collect(),
269                    });
270                }
271            }
272        }
273
274        metadata
275    }
276
277    pub fn register_project(
278        &mut self,
279        host_connection_id: ConnectionId,
280        host_user_id: UserId,
281    ) -> u64 {
282        let project_id = self.next_project_id;
283        self.projects.insert(
284            project_id,
285            Project {
286                host_connection_id,
287                host_user_id,
288                guests: Default::default(),
289                join_requests: Default::default(),
290                active_replica_ids: Default::default(),
291                worktrees: Default::default(),
292                language_servers: Default::default(),
293            },
294        );
295        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
296            connection.projects.insert(project_id);
297        }
298        self.next_project_id += 1;
299        project_id
300    }
301
302    pub fn register_worktree(
303        &mut self,
304        project_id: u64,
305        worktree_id: u64,
306        connection_id: ConnectionId,
307        worktree: Worktree,
308    ) -> Result<()> {
309        let project = self
310            .projects
311            .get_mut(&project_id)
312            .ok_or_else(|| anyhow!("no such project"))?;
313        if project.host_connection_id == connection_id {
314            project.worktrees.insert(worktree_id, worktree);
315            Ok(())
316        } else {
317            Err(anyhow!("no such project"))?
318        }
319    }
320
321    pub fn unregister_project(
322        &mut self,
323        project_id: u64,
324        connection_id: ConnectionId,
325    ) -> Result<Project> {
326        match self.projects.entry(project_id) {
327            hash_map::Entry::Occupied(e) => {
328                if e.get().host_connection_id == connection_id {
329                    let project = e.remove();
330
331                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
332                        host_connection.projects.remove(&project_id);
333                    }
334
335                    for guest_connection in project.guests.keys() {
336                        if let Some(connection) = self.connections.get_mut(&guest_connection) {
337                            connection.projects.remove(&project_id);
338                        }
339                    }
340
341                    for requester_user_id in project.join_requests.keys() {
342                        if let Some(requester_connection_ids) =
343                            self.connections_by_user_id.get_mut(&requester_user_id)
344                        {
345                            for requester_connection_id in requester_connection_ids.iter() {
346                                if let Some(requester_connection) =
347                                    self.connections.get_mut(requester_connection_id)
348                                {
349                                    requester_connection.requested_projects.remove(&project_id);
350                                }
351                            }
352                        }
353                    }
354
355                    Ok(project)
356                } else {
357                    Err(anyhow!("no such project"))?
358                }
359            }
360            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
361        }
362    }
363
364    pub fn unregister_worktree(
365        &mut self,
366        project_id: u64,
367        worktree_id: u64,
368        acting_connection_id: ConnectionId,
369    ) -> Result<(Worktree, Vec<ConnectionId>)> {
370        let project = self
371            .projects
372            .get_mut(&project_id)
373            .ok_or_else(|| anyhow!("no such project"))?;
374        if project.host_connection_id != acting_connection_id {
375            Err(anyhow!("not your worktree"))?;
376        }
377
378        let worktree = project
379            .worktrees
380            .remove(&worktree_id)
381            .ok_or_else(|| anyhow!("no such worktree"))?;
382        Ok((worktree, project.guest_connection_ids()))
383    }
384
385    pub fn update_diagnostic_summary(
386        &mut self,
387        project_id: u64,
388        worktree_id: u64,
389        connection_id: ConnectionId,
390        summary: proto::DiagnosticSummary,
391    ) -> Result<Vec<ConnectionId>> {
392        let project = self
393            .projects
394            .get_mut(&project_id)
395            .ok_or_else(|| anyhow!("no such project"))?;
396        if project.host_connection_id == connection_id {
397            let worktree = project
398                .worktrees
399                .get_mut(&worktree_id)
400                .ok_or_else(|| anyhow!("no such worktree"))?;
401            worktree
402                .diagnostic_summaries
403                .insert(summary.path.clone().into(), summary);
404            return Ok(project.connection_ids());
405        }
406
407        Err(anyhow!("no such worktree"))?
408    }
409
410    pub fn start_language_server(
411        &mut self,
412        project_id: u64,
413        connection_id: ConnectionId,
414        language_server: proto::LanguageServer,
415    ) -> Result<Vec<ConnectionId>> {
416        let project = self
417            .projects
418            .get_mut(&project_id)
419            .ok_or_else(|| anyhow!("no such project"))?;
420        if project.host_connection_id == connection_id {
421            project.language_servers.push(language_server);
422            return Ok(project.connection_ids());
423        }
424
425        Err(anyhow!("no such project"))?
426    }
427
428    pub fn request_join_project(
429        &mut self,
430        requester_id: UserId,
431        project_id: u64,
432        receipt: Receipt<proto::JoinProject>,
433    ) -> Result<()> {
434        let connection = self
435            .connections
436            .get_mut(&receipt.sender_id)
437            .ok_or_else(|| anyhow!("no such connection"))?;
438        let project = self
439            .projects
440            .get_mut(&project_id)
441            .ok_or_else(|| anyhow!("no such project"))?;
442        connection.requested_projects.insert(project_id);
443        project
444            .join_requests
445            .entry(requester_id)
446            .or_default()
447            .push(receipt);
448        Ok(())
449    }
450
451    pub fn deny_join_project_request(
452        &mut self,
453        responder_connection_id: ConnectionId,
454        requester_id: UserId,
455        project_id: u64,
456    ) -> Option<Vec<Receipt<proto::JoinProject>>> {
457        let project = self.projects.get_mut(&project_id)?;
458        if responder_connection_id != project.host_connection_id {
459            return None;
460        }
461
462        let receipts = project.join_requests.remove(&requester_id)?;
463        for receipt in &receipts {
464            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
465            requester_connection.requested_projects.remove(&project_id);
466        }
467        Some(receipts)
468    }
469
470    pub fn accept_join_project_request(
471        &mut self,
472        responder_connection_id: ConnectionId,
473        requester_id: UserId,
474        project_id: u64,
475    ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
476        let project = self.projects.get_mut(&project_id)?;
477        if responder_connection_id != project.host_connection_id {
478            return None;
479        }
480
481        let receipts = project.join_requests.remove(&requester_id)?;
482        let mut receipts_with_replica_ids = Vec::new();
483        for receipt in receipts {
484            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
485            requester_connection.requested_projects.remove(&project_id);
486            requester_connection.projects.insert(project_id);
487            let mut replica_id = 1;
488            while project.active_replica_ids.contains(&replica_id) {
489                replica_id += 1;
490            }
491            project.active_replica_ids.insert(replica_id);
492            project
493                .guests
494                .insert(receipt.sender_id, (replica_id, requester_id));
495            receipts_with_replica_ids.push((receipt, replica_id));
496        }
497
498        Some((receipts_with_replica_ids, project))
499    }
500
501    pub fn leave_project(
502        &mut self,
503        connection_id: ConnectionId,
504        project_id: u64,
505    ) -> Result<LeftProject> {
506        let project = self
507            .projects
508            .get_mut(&project_id)
509            .ok_or_else(|| anyhow!("no such project"))?;
510        let (replica_id, _) = project
511            .guests
512            .remove(&connection_id)
513            .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514        project.active_replica_ids.remove(&replica_id);
515
516        if let Some(connection) = self.connections.get_mut(&connection_id) {
517            connection.projects.remove(&project_id);
518        }
519
520        Ok(LeftProject {
521            connection_ids: project.connection_ids(),
522            host_connection_id: project.host_connection_id,
523            host_user_id: project.host_user_id,
524        })
525    }
526
527    pub fn update_worktree(
528        &mut self,
529        connection_id: ConnectionId,
530        project_id: u64,
531        worktree_id: u64,
532        removed_entries: &[u64],
533        updated_entries: &[proto::Entry],
534        scan_id: u64,
535    ) -> Result<Vec<ConnectionId>> {
536        let project = self.write_project(project_id, connection_id)?;
537        let worktree = project
538            .worktrees
539            .get_mut(&worktree_id)
540            .ok_or_else(|| anyhow!("no such worktree"))?;
541        for entry_id in removed_entries {
542            worktree.entries.remove(&entry_id);
543        }
544        for entry in updated_entries {
545            worktree.entries.insert(entry.id, entry.clone());
546        }
547        worktree.scan_id = scan_id;
548        let connection_ids = project.connection_ids();
549        Ok(connection_ids)
550    }
551
552    pub fn project_connection_ids(
553        &self,
554        project_id: u64,
555        acting_connection_id: ConnectionId,
556    ) -> Result<Vec<ConnectionId>> {
557        Ok(self
558            .read_project(project_id, acting_connection_id)?
559            .connection_ids())
560    }
561
562    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
563        Ok(self
564            .channels
565            .get(&channel_id)
566            .ok_or_else(|| anyhow!("no such channel"))?
567            .connection_ids())
568    }
569
570    pub fn project(&self, project_id: u64) -> Result<&Project> {
571        self.projects
572            .get(&project_id)
573            .ok_or_else(|| anyhow!("no such project"))
574    }
575
576    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
577        let project = self
578            .projects
579            .get(&project_id)
580            .ok_or_else(|| anyhow!("no such project"))?;
581        if project.host_connection_id == connection_id
582            || project.guests.contains_key(&connection_id)
583        {
584            Ok(project)
585        } else {
586            Err(anyhow!("no such project"))?
587        }
588    }
589
590    fn write_project(
591        &mut self,
592        project_id: u64,
593        connection_id: ConnectionId,
594    ) -> Result<&mut Project> {
595        let project = self
596            .projects
597            .get_mut(&project_id)
598            .ok_or_else(|| anyhow!("no such project"))?;
599        if project.host_connection_id == connection_id
600            || project.guests.contains_key(&connection_id)
601        {
602            Ok(project)
603        } else {
604            Err(anyhow!("no such project"))?
605        }
606    }
607
608    #[cfg(test)]
609    pub fn check_invariants(&self) {
610        for (connection_id, connection) in &self.connections {
611            for project_id in &connection.projects {
612                let project = &self.projects.get(&project_id).unwrap();
613                if project.host_connection_id != *connection_id {
614                    assert!(project.guests.contains_key(connection_id));
615                }
616
617                for (worktree_id, worktree) in project.worktrees.iter() {
618                    let mut paths = HashMap::default();
619                    for entry in worktree.entries.values() {
620                        let prev_entry = paths.insert(&entry.path, entry);
621                        assert_eq!(
622                            prev_entry,
623                            None,
624                            "worktree {:?}, duplicate path for entries {:?} and {:?}",
625                            worktree_id,
626                            prev_entry.unwrap(),
627                            entry
628                        );
629                    }
630                }
631            }
632            for channel_id in &connection.channels {
633                let channel = self.channels.get(channel_id).unwrap();
634                assert!(channel.connection_ids.contains(connection_id));
635            }
636            assert!(self
637                .connections_by_user_id
638                .get(&connection.user_id)
639                .unwrap()
640                .contains(connection_id));
641        }
642
643        for (user_id, connection_ids) in &self.connections_by_user_id {
644            for connection_id in connection_ids {
645                assert_eq!(
646                    self.connections.get(connection_id).unwrap().user_id,
647                    *user_id
648                );
649            }
650        }
651
652        for (project_id, project) in &self.projects {
653            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
654            assert!(host_connection.projects.contains(project_id));
655
656            for guest_connection_id in project.guests.keys() {
657                let guest_connection = self.connections.get(guest_connection_id).unwrap();
658                assert!(guest_connection.projects.contains(project_id));
659            }
660            assert_eq!(project.active_replica_ids.len(), project.guests.len(),);
661            assert_eq!(
662                project.active_replica_ids,
663                project
664                    .guests
665                    .values()
666                    .map(|(replica_id, _)| *replica_id)
667                    .collect::<HashSet<_>>(),
668            );
669        }
670
671        for (channel_id, channel) in &self.channels {
672            for connection_id in &channel.connection_ids {
673                let connection = self.connections.get(connection_id).unwrap();
674                assert!(connection.channels.contains(channel_id));
675            }
676        }
677    }
678}
679
680impl Project {
681    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
682        self.guests.keys().copied().collect()
683    }
684
685    pub fn connection_ids(&self) -> Vec<ConnectionId> {
686        self.guests
687            .keys()
688            .copied()
689            .chain(Some(self.host_connection_id))
690            .collect()
691    }
692}
693
694impl Channel {
695    fn connection_ids(&self) -> Vec<ConnectionId> {
696        self.connection_ids.iter().copied().collect()
697    }
698}