store.rs

  1use crate::db::{self, ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId, Receipt};
  5use serde::Serialize;
  6use std::{collections::hash_map, mem, path::PathBuf};
  7use tracing::instrument;
  8
  9#[derive(Default, Serialize)]
 10pub struct Store {
 11    connections: HashMap<ConnectionId, ConnectionState>,
 12    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 13    projects: HashMap<u64, Project>,
 14    #[serde(skip)]
 15    channels: HashMap<ChannelId, Channel>,
 16    next_project_id: u64,
 17}
 18
 19#[derive(Serialize)]
 20struct ConnectionState {
 21    user_id: UserId,
 22    projects: HashSet<u64>,
 23    requested_projects: HashSet<u64>,
 24    channels: HashSet<ChannelId>,
 25}
 26
 27#[derive(Serialize)]
 28pub struct Project {
 29    pub host_connection_id: ConnectionId,
 30    pub host_user_id: UserId,
 31    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 32    #[serde(skip)]
 33    pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
 34    pub active_replica_ids: HashSet<ReplicaId>,
 35    pub worktrees: HashMap<u64, Worktree>,
 36    pub language_servers: Vec<proto::LanguageServer>,
 37}
 38
 39#[derive(Default, Serialize)]
 40pub struct Worktree {
 41    pub root_name: String,
 42    pub visible: bool,
 43    #[serde(skip)]
 44    pub entries: HashMap<u64, proto::Entry>,
 45    #[serde(skip)]
 46    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 47    pub scan_id: u64,
 48}
 49
 50#[derive(Default)]
 51pub struct Channel {
 52    pub connection_ids: HashSet<ConnectionId>,
 53}
 54
 55pub type ReplicaId = u16;
 56
 57#[derive(Default)]
 58pub struct RemovedConnectionState {
 59    pub user_id: UserId,
 60    pub hosted_projects: HashMap<u64, Project>,
 61    pub guest_project_ids: HashSet<u64>,
 62    pub contact_ids: HashSet<UserId>,
 63}
 64
 65pub struct LeftProject {
 66    pub host_user_id: UserId,
 67    pub host_connection_id: ConnectionId,
 68    pub connection_ids: Vec<ConnectionId>,
 69    pub remove_collaborator: bool,
 70    pub cancel_request: Option<UserId>,
 71    pub unshare: bool,
 72}
 73
 74#[derive(Copy, Clone)]
 75pub struct Metrics {
 76    pub connections: usize,
 77    pub registered_projects: usize,
 78    pub shared_projects: usize,
 79}
 80
 81impl Store {
 82    pub fn metrics(&self) -> Metrics {
 83        let connections = self.connections.len();
 84        let mut registered_projects = 0;
 85        let mut shared_projects = 0;
 86        for project in self.projects.values() {
 87            registered_projects += 1;
 88            if !project.guests.is_empty() {
 89                shared_projects += 1;
 90            }
 91        }
 92
 93        Metrics {
 94            connections,
 95            registered_projects,
 96            shared_projects,
 97        }
 98    }
 99
100    #[instrument(skip(self))]
101    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
102        self.connections.insert(
103            connection_id,
104            ConnectionState {
105                user_id,
106                projects: Default::default(),
107                requested_projects: Default::default(),
108                channels: Default::default(),
109            },
110        );
111        self.connections_by_user_id
112            .entry(user_id)
113            .or_default()
114            .insert(connection_id);
115    }
116
117    #[instrument(skip(self))]
118    pub fn remove_connection(
119        &mut self,
120        connection_id: ConnectionId,
121    ) -> Result<RemovedConnectionState> {
122        let connection = self
123            .connections
124            .get_mut(&connection_id)
125            .ok_or_else(|| anyhow!("no such connection"))?;
126
127        let user_id = connection.user_id;
128        let connection_projects = mem::take(&mut connection.projects);
129        let connection_channels = mem::take(&mut connection.channels);
130
131        let mut result = RemovedConnectionState::default();
132        result.user_id = user_id;
133
134        // Leave all channels.
135        for channel_id in connection_channels {
136            self.leave_channel(connection_id, channel_id);
137        }
138
139        // Unregister and leave all projects.
140        for project_id in connection_projects {
141            if let Ok(project) = self.unregister_project(project_id, connection_id) {
142                result.hosted_projects.insert(project_id, project);
143            } else if self.leave_project(connection_id, project_id).is_ok() {
144                result.guest_project_ids.insert(project_id);
145            }
146        }
147
148        let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap();
149        user_connections.remove(&connection_id);
150        if user_connections.is_empty() {
151            self.connections_by_user_id.remove(&user_id);
152        }
153
154        self.connections.remove(&connection_id).unwrap();
155
156        Ok(result)
157    }
158
159    #[cfg(test)]
160    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
161        self.channels.get(&id)
162    }
163
164    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
165        if let Some(connection) = self.connections.get_mut(&connection_id) {
166            connection.channels.insert(channel_id);
167            self.channels
168                .entry(channel_id)
169                .or_default()
170                .connection_ids
171                .insert(connection_id);
172        }
173    }
174
175    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
176        if let Some(connection) = self.connections.get_mut(&connection_id) {
177            connection.channels.remove(&channel_id);
178            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
179                entry.get_mut().connection_ids.remove(&connection_id);
180                if entry.get_mut().connection_ids.is_empty() {
181                    entry.remove();
182                }
183            }
184        }
185    }
186
187    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
188        Ok(self
189            .connections
190            .get(&connection_id)
191            .ok_or_else(|| anyhow!("unknown connection"))?
192            .user_id)
193    }
194
195    pub fn connection_ids_for_user<'a>(
196        &'a self,
197        user_id: UserId,
198    ) -> impl 'a + Iterator<Item = ConnectionId> {
199        self.connections_by_user_id
200            .get(&user_id)
201            .into_iter()
202            .flatten()
203            .copied()
204    }
205
206    pub fn is_user_online(&self, user_id: UserId) -> bool {
207        !self
208            .connections_by_user_id
209            .get(&user_id)
210            .unwrap_or(&Default::default())
211            .is_empty()
212    }
213
214    pub fn build_initial_contacts_update(
215        &self,
216        contacts: Vec<db::Contact>,
217    ) -> proto::UpdateContacts {
218        let mut update = proto::UpdateContacts::default();
219
220        for contact in contacts {
221            match contact {
222                db::Contact::Accepted {
223                    user_id,
224                    should_notify,
225                } => {
226                    update
227                        .contacts
228                        .push(self.contact_for_user(user_id, should_notify));
229                }
230                db::Contact::Outgoing { user_id } => {
231                    update.outgoing_requests.push(user_id.to_proto())
232                }
233                db::Contact::Incoming {
234                    user_id,
235                    should_notify,
236                } => update
237                    .incoming_requests
238                    .push(proto::IncomingContactRequest {
239                        requester_id: user_id.to_proto(),
240                        should_notify,
241                    }),
242            }
243        }
244
245        update
246    }
247
248    pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
249        proto::Contact {
250            user_id: user_id.to_proto(),
251            projects: self.project_metadata_for_user(user_id),
252            online: self.is_user_online(user_id),
253            should_notify,
254        }
255    }
256
257    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
258        let connection_ids = self.connections_by_user_id.get(&user_id);
259        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
260            connection_ids
261                .iter()
262                .filter_map(|connection_id| self.connections.get(connection_id))
263                .flat_map(|connection| connection.projects.iter().copied())
264        });
265
266        let mut metadata = Vec::new();
267        for project_id in project_ids {
268            if let Some(project) = self.projects.get(&project_id) {
269                if project.host_user_id == user_id {
270                    metadata.push(proto::ProjectMetadata {
271                        id: project_id,
272                        worktree_root_names: project
273                            .worktrees
274                            .values()
275                            .map(|worktree| worktree.root_name.clone())
276                            .collect(),
277                        guests: project
278                            .guests
279                            .values()
280                            .map(|(_, user_id)| user_id.to_proto())
281                            .collect(),
282                    });
283                }
284            }
285        }
286
287        metadata
288    }
289
290    pub fn register_project(
291        &mut self,
292        host_connection_id: ConnectionId,
293        host_user_id: UserId,
294    ) -> u64 {
295        let project_id = self.next_project_id;
296        self.projects.insert(
297            project_id,
298            Project {
299                host_connection_id,
300                host_user_id,
301                guests: Default::default(),
302                join_requests: Default::default(),
303                active_replica_ids: Default::default(),
304                worktrees: Default::default(),
305                language_servers: Default::default(),
306            },
307        );
308        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
309            connection.projects.insert(project_id);
310        }
311        self.next_project_id += 1;
312        project_id
313    }
314
315    pub fn register_worktree(
316        &mut self,
317        project_id: u64,
318        worktree_id: u64,
319        connection_id: ConnectionId,
320        worktree: Worktree,
321    ) -> Result<()> {
322        let project = self
323            .projects
324            .get_mut(&project_id)
325            .ok_or_else(|| anyhow!("no such project"))?;
326        if project.host_connection_id == connection_id {
327            project.worktrees.insert(worktree_id, worktree);
328            Ok(())
329        } else {
330            Err(anyhow!("no such project"))?
331        }
332    }
333
334    pub fn unregister_project(
335        &mut self,
336        project_id: u64,
337        connection_id: ConnectionId,
338    ) -> Result<Project> {
339        match self.projects.entry(project_id) {
340            hash_map::Entry::Occupied(e) => {
341                if e.get().host_connection_id == connection_id {
342                    let project = e.remove();
343
344                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
345                        host_connection.projects.remove(&project_id);
346                    }
347
348                    for guest_connection in project.guests.keys() {
349                        if let Some(connection) = self.connections.get_mut(&guest_connection) {
350                            connection.projects.remove(&project_id);
351                        }
352                    }
353
354                    for requester_user_id in project.join_requests.keys() {
355                        if let Some(requester_connection_ids) =
356                            self.connections_by_user_id.get_mut(&requester_user_id)
357                        {
358                            for requester_connection_id in requester_connection_ids.iter() {
359                                if let Some(requester_connection) =
360                                    self.connections.get_mut(requester_connection_id)
361                                {
362                                    requester_connection.requested_projects.remove(&project_id);
363                                }
364                            }
365                        }
366                    }
367
368                    Ok(project)
369                } else {
370                    Err(anyhow!("no such project"))?
371                }
372            }
373            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
374        }
375    }
376
377    pub fn unregister_worktree(
378        &mut self,
379        project_id: u64,
380        worktree_id: u64,
381        acting_connection_id: ConnectionId,
382    ) -> Result<(Worktree, Vec<ConnectionId>)> {
383        let project = self
384            .projects
385            .get_mut(&project_id)
386            .ok_or_else(|| anyhow!("no such project"))?;
387        if project.host_connection_id != acting_connection_id {
388            Err(anyhow!("not your worktree"))?;
389        }
390
391        let worktree = project
392            .worktrees
393            .remove(&worktree_id)
394            .ok_or_else(|| anyhow!("no such worktree"))?;
395        Ok((worktree, project.guest_connection_ids()))
396    }
397
398    pub fn update_diagnostic_summary(
399        &mut self,
400        project_id: u64,
401        worktree_id: u64,
402        connection_id: ConnectionId,
403        summary: proto::DiagnosticSummary,
404    ) -> Result<Vec<ConnectionId>> {
405        let project = self
406            .projects
407            .get_mut(&project_id)
408            .ok_or_else(|| anyhow!("no such project"))?;
409        if project.host_connection_id == connection_id {
410            let worktree = project
411                .worktrees
412                .get_mut(&worktree_id)
413                .ok_or_else(|| anyhow!("no such worktree"))?;
414            worktree
415                .diagnostic_summaries
416                .insert(summary.path.clone().into(), summary);
417            return Ok(project.connection_ids());
418        }
419
420        Err(anyhow!("no such worktree"))?
421    }
422
423    pub fn start_language_server(
424        &mut self,
425        project_id: u64,
426        connection_id: ConnectionId,
427        language_server: proto::LanguageServer,
428    ) -> Result<Vec<ConnectionId>> {
429        let project = self
430            .projects
431            .get_mut(&project_id)
432            .ok_or_else(|| anyhow!("no such project"))?;
433        if project.host_connection_id == connection_id {
434            project.language_servers.push(language_server);
435            return Ok(project.connection_ids());
436        }
437
438        Err(anyhow!("no such project"))?
439    }
440
441    pub fn request_join_project(
442        &mut self,
443        requester_id: UserId,
444        project_id: u64,
445        receipt: Receipt<proto::JoinProject>,
446    ) -> Result<()> {
447        let connection = self
448            .connections
449            .get_mut(&receipt.sender_id)
450            .ok_or_else(|| anyhow!("no such connection"))?;
451        let project = self
452            .projects
453            .get_mut(&project_id)
454            .ok_or_else(|| anyhow!("no such project"))?;
455        connection.requested_projects.insert(project_id);
456        project
457            .join_requests
458            .entry(requester_id)
459            .or_default()
460            .push(receipt);
461        Ok(())
462    }
463
464    pub fn deny_join_project_request(
465        &mut self,
466        responder_connection_id: ConnectionId,
467        requester_id: UserId,
468        project_id: u64,
469    ) -> Option<Vec<Receipt<proto::JoinProject>>> {
470        let project = self.projects.get_mut(&project_id)?;
471        if responder_connection_id != project.host_connection_id {
472            return None;
473        }
474
475        let receipts = project.join_requests.remove(&requester_id)?;
476        for receipt in &receipts {
477            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
478            requester_connection.requested_projects.remove(&project_id);
479        }
480        Some(receipts)
481    }
482
483    pub fn accept_join_project_request(
484        &mut self,
485        responder_connection_id: ConnectionId,
486        requester_id: UserId,
487        project_id: u64,
488    ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
489        let project = self.projects.get_mut(&project_id)?;
490        if responder_connection_id != project.host_connection_id {
491            return None;
492        }
493
494        let receipts = project.join_requests.remove(&requester_id)?;
495        let mut receipts_with_replica_ids = Vec::new();
496        for receipt in receipts {
497            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
498            requester_connection.requested_projects.remove(&project_id);
499            requester_connection.projects.insert(project_id);
500            let mut replica_id = 1;
501            while project.active_replica_ids.contains(&replica_id) {
502                replica_id += 1;
503            }
504            project.active_replica_ids.insert(replica_id);
505            project
506                .guests
507                .insert(receipt.sender_id, (replica_id, requester_id));
508            receipts_with_replica_ids.push((receipt, replica_id));
509        }
510
511        Some((receipts_with_replica_ids, project))
512    }
513
514    pub fn leave_project(
515        &mut self,
516        connection_id: ConnectionId,
517        project_id: u64,
518    ) -> Result<LeftProject> {
519        let user_id = self.user_id_for_connection(connection_id)?;
520        let project = self
521            .projects
522            .get_mut(&project_id)
523            .ok_or_else(|| anyhow!("no such project"))?;
524
525        // If the connection leaving the project is a collaborator, remove it.
526        let remove_collaborator =
527            if let Some((replica_id, _)) = project.guests.remove(&connection_id) {
528                project.active_replica_ids.remove(&replica_id);
529                true
530            } else {
531                false
532            };
533
534        // If the connection leaving the project has a pending request, remove it.
535        // If that user has no other pending requests on other connections, indicate that the request should be cancelled.
536        let mut cancel_request = None;
537        if let Entry::Occupied(mut entry) = project.join_requests.entry(user_id) {
538            entry
539                .get_mut()
540                .retain(|receipt| receipt.sender_id != connection_id);
541            if entry.get().is_empty() {
542                entry.remove();
543                cancel_request = Some(user_id);
544            }
545        }
546
547        if let Some(connection) = self.connections.get_mut(&connection_id) {
548            connection.projects.remove(&project_id);
549        }
550
551        let connection_ids = project.connection_ids();
552        let unshare = connection_ids.len() <= 1 && project.join_requests.is_empty();
553        if unshare {
554            project.language_servers.clear();
555            for worktree in project.worktrees.values_mut() {
556                worktree.diagnostic_summaries.clear();
557                worktree.entries.clear();
558            }
559        }
560
561        Ok(LeftProject {
562            host_connection_id: project.host_connection_id,
563            host_user_id: project.host_user_id,
564            connection_ids,
565            cancel_request,
566            unshare,
567            remove_collaborator,
568        })
569    }
570
571    pub fn update_worktree(
572        &mut self,
573        connection_id: ConnectionId,
574        project_id: u64,
575        worktree_id: u64,
576        removed_entries: &[u64],
577        updated_entries: &[proto::Entry],
578        scan_id: u64,
579    ) -> Result<Vec<ConnectionId>> {
580        let project = self.write_project(project_id, connection_id)?;
581        let worktree = project
582            .worktrees
583            .get_mut(&worktree_id)
584            .ok_or_else(|| anyhow!("no such worktree"))?;
585        for entry_id in removed_entries {
586            worktree.entries.remove(&entry_id);
587        }
588        for entry in updated_entries {
589            worktree.entries.insert(entry.id, entry.clone());
590        }
591        worktree.scan_id = scan_id;
592        let connection_ids = project.connection_ids();
593        Ok(connection_ids)
594    }
595
596    pub fn project_connection_ids(
597        &self,
598        project_id: u64,
599        acting_connection_id: ConnectionId,
600    ) -> Result<Vec<ConnectionId>> {
601        Ok(self
602            .read_project(project_id, acting_connection_id)?
603            .connection_ids())
604    }
605
606    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
607        Ok(self
608            .channels
609            .get(&channel_id)
610            .ok_or_else(|| anyhow!("no such channel"))?
611            .connection_ids())
612    }
613
614    pub fn project(&self, project_id: u64) -> Result<&Project> {
615        self.projects
616            .get(&project_id)
617            .ok_or_else(|| anyhow!("no such project"))
618    }
619
620    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
621        let project = self
622            .projects
623            .get(&project_id)
624            .ok_or_else(|| anyhow!("no such project"))?;
625        if project.host_connection_id == connection_id
626            || project.guests.contains_key(&connection_id)
627        {
628            Ok(project)
629        } else {
630            Err(anyhow!("no such project"))?
631        }
632    }
633
634    fn write_project(
635        &mut self,
636        project_id: u64,
637        connection_id: ConnectionId,
638    ) -> Result<&mut Project> {
639        let project = self
640            .projects
641            .get_mut(&project_id)
642            .ok_or_else(|| anyhow!("no such project"))?;
643        if project.host_connection_id == connection_id
644            || project.guests.contains_key(&connection_id)
645        {
646            Ok(project)
647        } else {
648            Err(anyhow!("no such project"))?
649        }
650    }
651
652    #[cfg(test)]
653    pub fn check_invariants(&self) {
654        for (connection_id, connection) in &self.connections {
655            for project_id in &connection.projects {
656                let project = &self.projects.get(&project_id).unwrap();
657                if project.host_connection_id != *connection_id {
658                    assert!(project.guests.contains_key(connection_id));
659                }
660
661                for (worktree_id, worktree) in project.worktrees.iter() {
662                    let mut paths = HashMap::default();
663                    for entry in worktree.entries.values() {
664                        let prev_entry = paths.insert(&entry.path, entry);
665                        assert_eq!(
666                            prev_entry,
667                            None,
668                            "worktree {:?}, duplicate path for entries {:?} and {:?}",
669                            worktree_id,
670                            prev_entry.unwrap(),
671                            entry
672                        );
673                    }
674                }
675            }
676            for channel_id in &connection.channels {
677                let channel = self.channels.get(channel_id).unwrap();
678                assert!(channel.connection_ids.contains(connection_id));
679            }
680            assert!(self
681                .connections_by_user_id
682                .get(&connection.user_id)
683                .unwrap()
684                .contains(connection_id));
685        }
686
687        for (user_id, connection_ids) in &self.connections_by_user_id {
688            for connection_id in connection_ids {
689                assert_eq!(
690                    self.connections.get(connection_id).unwrap().user_id,
691                    *user_id
692                );
693            }
694        }
695
696        for (project_id, project) in &self.projects {
697            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
698            assert!(host_connection.projects.contains(project_id));
699
700            for guest_connection_id in project.guests.keys() {
701                let guest_connection = self.connections.get(guest_connection_id).unwrap();
702                assert!(guest_connection.projects.contains(project_id));
703            }
704            assert_eq!(project.active_replica_ids.len(), project.guests.len(),);
705            assert_eq!(
706                project.active_replica_ids,
707                project
708                    .guests
709                    .values()
710                    .map(|(replica_id, _)| *replica_id)
711                    .collect::<HashSet<_>>(),
712            );
713        }
714
715        for (channel_id, channel) in &self.channels {
716            for connection_id in &channel.connection_ids {
717                let connection = self.connections.get(connection_id).unwrap();
718                assert!(connection.channels.contains(channel_id));
719            }
720        }
721    }
722}
723
724impl Project {
725    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
726        self.guests.keys().copied().collect()
727    }
728
729    pub fn connection_ids(&self) -> Vec<ConnectionId> {
730        self.guests
731            .keys()
732            .copied()
733            .chain(Some(self.host_connection_id))
734            .collect()
735    }
736}
737
738impl Channel {
739    fn connection_ids(&self) -> Vec<ConnectionId> {
740        self.connection_ids.iter().copied().collect()
741    }
742}