store.rs

  1use crate::db::{self, ChannelId, UserId};
  2use anyhow::{anyhow, Result};
  3use collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
  4use rpc::{proto, ConnectionId, Receipt};
  5use std::{collections::hash_map, mem, path::PathBuf};
  6use tracing::instrument;
  7
  8#[derive(Default)]
  9pub struct Store {
 10    connections: HashMap<ConnectionId, ConnectionState>,
 11    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
 12    projects: HashMap<u64, Project>,
 13    channels: HashMap<ChannelId, Channel>,
 14    next_project_id: u64,
 15}
 16
 17struct ConnectionState {
 18    user_id: UserId,
 19    projects: HashSet<u64>,
 20    requested_projects: HashSet<u64>,
 21    channels: HashSet<ChannelId>,
 22}
 23
 24pub struct Project {
 25    pub host_connection_id: ConnectionId,
 26    pub host_user_id: UserId,
 27    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
 28    pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
 29    pub active_replica_ids: HashSet<ReplicaId>,
 30    pub worktrees: HashMap<u64, Worktree>,
 31    pub language_servers: Vec<proto::LanguageServer>,
 32}
 33
 34#[derive(Default)]
 35pub struct Worktree {
 36    pub root_name: String,
 37    pub visible: bool,
 38    pub entries: HashMap<u64, proto::Entry>,
 39    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
 40    pub scan_id: u64,
 41}
 42
 43#[derive(Default)]
 44pub struct Channel {
 45    pub connection_ids: HashSet<ConnectionId>,
 46}
 47
 48pub type ReplicaId = u16;
 49
 50#[derive(Default)]
 51pub struct RemovedConnectionState {
 52    pub user_id: UserId,
 53    pub hosted_projects: HashMap<u64, Project>,
 54    pub guest_project_ids: HashSet<u64>,
 55    pub contact_ids: HashSet<UserId>,
 56}
 57
 58pub struct LeftProject {
 59    pub host_user_id: UserId,
 60    pub host_connection_id: ConnectionId,
 61    pub connection_ids: Vec<ConnectionId>,
 62    pub remove_collaborator: bool,
 63    pub cancel_request: Option<UserId>,
 64    pub unshare: bool,
 65}
 66
 67#[derive(Copy, Clone)]
 68pub struct Metrics {
 69    pub connections: usize,
 70    pub registered_projects: usize,
 71    pub shared_projects: usize,
 72}
 73
 74impl Store {
 75    pub fn metrics(&self) -> Metrics {
 76        let connections = self.connections.len();
 77        let mut registered_projects = 0;
 78        let mut shared_projects = 0;
 79        for project in self.projects.values() {
 80            registered_projects += 1;
 81            if !project.guests.is_empty() {
 82                shared_projects += 1;
 83            }
 84        }
 85
 86        Metrics {
 87            connections,
 88            registered_projects,
 89            shared_projects,
 90        }
 91    }
 92
 93    #[instrument(skip(self))]
 94    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
 95        self.connections.insert(
 96            connection_id,
 97            ConnectionState {
 98                user_id,
 99                projects: Default::default(),
100                requested_projects: Default::default(),
101                channels: Default::default(),
102            },
103        );
104        self.connections_by_user_id
105            .entry(user_id)
106            .or_default()
107            .insert(connection_id);
108    }
109
110    #[instrument(skip(self))]
111    pub fn remove_connection(
112        &mut self,
113        connection_id: ConnectionId,
114    ) -> Result<RemovedConnectionState> {
115        let connection = self
116            .connections
117            .get_mut(&connection_id)
118            .ok_or_else(|| anyhow!("no such connection"))?;
119
120        let user_id = connection.user_id;
121        let connection_projects = mem::take(&mut connection.projects);
122        let connection_channels = mem::take(&mut connection.channels);
123
124        let mut result = RemovedConnectionState::default();
125        result.user_id = user_id;
126
127        // Leave all channels.
128        for channel_id in connection_channels {
129            self.leave_channel(connection_id, channel_id);
130        }
131
132        // Unregister and leave all projects.
133        for project_id in connection_projects {
134            if let Ok(project) = self.unregister_project(project_id, connection_id) {
135                result.hosted_projects.insert(project_id, project);
136            } else if self.leave_project(connection_id, project_id).is_ok() {
137                result.guest_project_ids.insert(project_id);
138            }
139        }
140
141        let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap();
142        user_connections.remove(&connection_id);
143        if user_connections.is_empty() {
144            self.connections_by_user_id.remove(&user_id);
145        }
146
147        self.connections.remove(&connection_id).unwrap();
148
149        Ok(result)
150    }
151
152    #[cfg(test)]
153    pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
154        self.channels.get(&id)
155    }
156
157    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
158        if let Some(connection) = self.connections.get_mut(&connection_id) {
159            connection.channels.insert(channel_id);
160            self.channels
161                .entry(channel_id)
162                .or_default()
163                .connection_ids
164                .insert(connection_id);
165        }
166    }
167
168    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
169        if let Some(connection) = self.connections.get_mut(&connection_id) {
170            connection.channels.remove(&channel_id);
171            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
172                entry.get_mut().connection_ids.remove(&connection_id);
173                if entry.get_mut().connection_ids.is_empty() {
174                    entry.remove();
175                }
176            }
177        }
178    }
179
180    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
181        Ok(self
182            .connections
183            .get(&connection_id)
184            .ok_or_else(|| anyhow!("unknown connection"))?
185            .user_id)
186    }
187
188    pub fn connection_ids_for_user<'a>(
189        &'a self,
190        user_id: UserId,
191    ) -> impl 'a + Iterator<Item = ConnectionId> {
192        self.connections_by_user_id
193            .get(&user_id)
194            .into_iter()
195            .flatten()
196            .copied()
197    }
198
199    pub fn is_user_online(&self, user_id: UserId) -> bool {
200        !self
201            .connections_by_user_id
202            .get(&user_id)
203            .unwrap_or(&Default::default())
204            .is_empty()
205    }
206
207    pub fn build_initial_contacts_update(
208        &self,
209        contacts: Vec<db::Contact>,
210    ) -> proto::UpdateContacts {
211        let mut update = proto::UpdateContacts::default();
212
213        for contact in contacts {
214            match contact {
215                db::Contact::Accepted {
216                    user_id,
217                    should_notify,
218                } => {
219                    update
220                        .contacts
221                        .push(self.contact_for_user(user_id, should_notify));
222                }
223                db::Contact::Outgoing { user_id } => {
224                    update.outgoing_requests.push(user_id.to_proto())
225                }
226                db::Contact::Incoming {
227                    user_id,
228                    should_notify,
229                } => update
230                    .incoming_requests
231                    .push(proto::IncomingContactRequest {
232                        requester_id: user_id.to_proto(),
233                        should_notify,
234                    }),
235            }
236        }
237
238        update
239    }
240
241    pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
242        proto::Contact {
243            user_id: user_id.to_proto(),
244            projects: self.project_metadata_for_user(user_id),
245            online: self.is_user_online(user_id),
246            should_notify,
247        }
248    }
249
250    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
251        let connection_ids = self.connections_by_user_id.get(&user_id);
252        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
253            connection_ids
254                .iter()
255                .filter_map(|connection_id| self.connections.get(connection_id))
256                .flat_map(|connection| connection.projects.iter().copied())
257        });
258
259        let mut metadata = Vec::new();
260        for project_id in project_ids {
261            if let Some(project) = self.projects.get(&project_id) {
262                if project.host_user_id == user_id {
263                    metadata.push(proto::ProjectMetadata {
264                        id: project_id,
265                        worktree_root_names: project
266                            .worktrees
267                            .values()
268                            .map(|worktree| worktree.root_name.clone())
269                            .collect(),
270                        guests: project
271                            .guests
272                            .values()
273                            .map(|(_, user_id)| user_id.to_proto())
274                            .collect(),
275                    });
276                }
277            }
278        }
279
280        metadata
281    }
282
283    pub fn register_project(
284        &mut self,
285        host_connection_id: ConnectionId,
286        host_user_id: UserId,
287    ) -> u64 {
288        let project_id = self.next_project_id;
289        self.projects.insert(
290            project_id,
291            Project {
292                host_connection_id,
293                host_user_id,
294                guests: Default::default(),
295                join_requests: Default::default(),
296                active_replica_ids: Default::default(),
297                worktrees: Default::default(),
298                language_servers: Default::default(),
299            },
300        );
301        if let Some(connection) = self.connections.get_mut(&host_connection_id) {
302            connection.projects.insert(project_id);
303        }
304        self.next_project_id += 1;
305        project_id
306    }
307
308    pub fn register_worktree(
309        &mut self,
310        project_id: u64,
311        worktree_id: u64,
312        connection_id: ConnectionId,
313        worktree: Worktree,
314    ) -> Result<()> {
315        let project = self
316            .projects
317            .get_mut(&project_id)
318            .ok_or_else(|| anyhow!("no such project"))?;
319        if project.host_connection_id == connection_id {
320            project.worktrees.insert(worktree_id, worktree);
321            Ok(())
322        } else {
323            Err(anyhow!("no such project"))?
324        }
325    }
326
327    pub fn unregister_project(
328        &mut self,
329        project_id: u64,
330        connection_id: ConnectionId,
331    ) -> Result<Project> {
332        match self.projects.entry(project_id) {
333            hash_map::Entry::Occupied(e) => {
334                if e.get().host_connection_id == connection_id {
335                    let project = e.remove();
336
337                    if let Some(host_connection) = self.connections.get_mut(&connection_id) {
338                        host_connection.projects.remove(&project_id);
339                    }
340
341                    for guest_connection in project.guests.keys() {
342                        if let Some(connection) = self.connections.get_mut(&guest_connection) {
343                            connection.projects.remove(&project_id);
344                        }
345                    }
346
347                    for requester_user_id in project.join_requests.keys() {
348                        if let Some(requester_connection_ids) =
349                            self.connections_by_user_id.get_mut(&requester_user_id)
350                        {
351                            for requester_connection_id in requester_connection_ids.iter() {
352                                if let Some(requester_connection) =
353                                    self.connections.get_mut(requester_connection_id)
354                                {
355                                    requester_connection.requested_projects.remove(&project_id);
356                                }
357                            }
358                        }
359                    }
360
361                    Ok(project)
362                } else {
363                    Err(anyhow!("no such project"))?
364                }
365            }
366            hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
367        }
368    }
369
370    pub fn unregister_worktree(
371        &mut self,
372        project_id: u64,
373        worktree_id: u64,
374        acting_connection_id: ConnectionId,
375    ) -> Result<(Worktree, Vec<ConnectionId>)> {
376        let project = self
377            .projects
378            .get_mut(&project_id)
379            .ok_or_else(|| anyhow!("no such project"))?;
380        if project.host_connection_id != acting_connection_id {
381            Err(anyhow!("not your worktree"))?;
382        }
383
384        let worktree = project
385            .worktrees
386            .remove(&worktree_id)
387            .ok_or_else(|| anyhow!("no such worktree"))?;
388        Ok((worktree, project.guest_connection_ids()))
389    }
390
391    pub fn update_diagnostic_summary(
392        &mut self,
393        project_id: u64,
394        worktree_id: u64,
395        connection_id: ConnectionId,
396        summary: proto::DiagnosticSummary,
397    ) -> Result<Vec<ConnectionId>> {
398        let project = self
399            .projects
400            .get_mut(&project_id)
401            .ok_or_else(|| anyhow!("no such project"))?;
402        if project.host_connection_id == connection_id {
403            let worktree = project
404                .worktrees
405                .get_mut(&worktree_id)
406                .ok_or_else(|| anyhow!("no such worktree"))?;
407            worktree
408                .diagnostic_summaries
409                .insert(summary.path.clone().into(), summary);
410            return Ok(project.connection_ids());
411        }
412
413        Err(anyhow!("no such worktree"))?
414    }
415
416    pub fn start_language_server(
417        &mut self,
418        project_id: u64,
419        connection_id: ConnectionId,
420        language_server: proto::LanguageServer,
421    ) -> Result<Vec<ConnectionId>> {
422        let project = self
423            .projects
424            .get_mut(&project_id)
425            .ok_or_else(|| anyhow!("no such project"))?;
426        if project.host_connection_id == connection_id {
427            project.language_servers.push(language_server);
428            return Ok(project.connection_ids());
429        }
430
431        Err(anyhow!("no such project"))?
432    }
433
434    pub fn request_join_project(
435        &mut self,
436        requester_id: UserId,
437        project_id: u64,
438        receipt: Receipt<proto::JoinProject>,
439    ) -> Result<()> {
440        let connection = self
441            .connections
442            .get_mut(&receipt.sender_id)
443            .ok_or_else(|| anyhow!("no such connection"))?;
444        let project = self
445            .projects
446            .get_mut(&project_id)
447            .ok_or_else(|| anyhow!("no such project"))?;
448        connection.requested_projects.insert(project_id);
449        project
450            .join_requests
451            .entry(requester_id)
452            .or_default()
453            .push(receipt);
454        Ok(())
455    }
456
457    pub fn deny_join_project_request(
458        &mut self,
459        responder_connection_id: ConnectionId,
460        requester_id: UserId,
461        project_id: u64,
462    ) -> Option<Vec<Receipt<proto::JoinProject>>> {
463        let project = self.projects.get_mut(&project_id)?;
464        if responder_connection_id != project.host_connection_id {
465            return None;
466        }
467
468        let receipts = project.join_requests.remove(&requester_id)?;
469        for receipt in &receipts {
470            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
471            requester_connection.requested_projects.remove(&project_id);
472        }
473        Some(receipts)
474    }
475
476    pub fn accept_join_project_request(
477        &mut self,
478        responder_connection_id: ConnectionId,
479        requester_id: UserId,
480        project_id: u64,
481    ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
482        let project = self.projects.get_mut(&project_id)?;
483        if responder_connection_id != project.host_connection_id {
484            return None;
485        }
486
487        let receipts = project.join_requests.remove(&requester_id)?;
488        let mut receipts_with_replica_ids = Vec::new();
489        for receipt in receipts {
490            let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
491            requester_connection.requested_projects.remove(&project_id);
492            requester_connection.projects.insert(project_id);
493            let mut replica_id = 1;
494            while project.active_replica_ids.contains(&replica_id) {
495                replica_id += 1;
496            }
497            project.active_replica_ids.insert(replica_id);
498            project
499                .guests
500                .insert(receipt.sender_id, (replica_id, requester_id));
501            receipts_with_replica_ids.push((receipt, replica_id));
502        }
503
504        Some((receipts_with_replica_ids, project))
505    }
506
507    pub fn leave_project(
508        &mut self,
509        connection_id: ConnectionId,
510        project_id: u64,
511    ) -> Result<LeftProject> {
512        let user_id = self.user_id_for_connection(connection_id)?;
513        let project = self
514            .projects
515            .get_mut(&project_id)
516            .ok_or_else(|| anyhow!("no such project"))?;
517
518        // If the connection leaving the project is a collaborator, remove it.
519        let remove_collaborator =
520            if let Some((replica_id, _)) = project.guests.remove(&connection_id) {
521                project.active_replica_ids.remove(&replica_id);
522                true
523            } else {
524                false
525            };
526
527        // If the connection leaving the project has a pending request, remove it.
528        // If that user has no other pending requests on other connections, indicate that the request should be cancelled.
529        let mut cancel_request = None;
530        if let Entry::Occupied(mut entry) = project.join_requests.entry(user_id) {
531            entry
532                .get_mut()
533                .retain(|receipt| receipt.sender_id != connection_id);
534            if entry.get().is_empty() {
535                entry.remove();
536                cancel_request = Some(user_id);
537            }
538        }
539
540        if let Some(connection) = self.connections.get_mut(&connection_id) {
541            connection.projects.remove(&project_id);
542        }
543
544        let connection_ids = project.connection_ids();
545        let unshare = connection_ids.len() <= 1 && project.join_requests.is_empty();
546
547        Ok(LeftProject {
548            host_connection_id: project.host_connection_id,
549            host_user_id: project.host_user_id,
550            connection_ids,
551            cancel_request,
552            unshare,
553            remove_collaborator,
554        })
555    }
556
557    pub fn update_worktree(
558        &mut self,
559        connection_id: ConnectionId,
560        project_id: u64,
561        worktree_id: u64,
562        removed_entries: &[u64],
563        updated_entries: &[proto::Entry],
564        scan_id: u64,
565    ) -> Result<Vec<ConnectionId>> {
566        let project = self.write_project(project_id, connection_id)?;
567        let worktree = project
568            .worktrees
569            .get_mut(&worktree_id)
570            .ok_or_else(|| anyhow!("no such worktree"))?;
571        for entry_id in removed_entries {
572            worktree.entries.remove(&entry_id);
573        }
574        for entry in updated_entries {
575            worktree.entries.insert(entry.id, entry.clone());
576        }
577        worktree.scan_id = scan_id;
578        let connection_ids = project.connection_ids();
579        Ok(connection_ids)
580    }
581
582    pub fn project_connection_ids(
583        &self,
584        project_id: u64,
585        acting_connection_id: ConnectionId,
586    ) -> Result<Vec<ConnectionId>> {
587        Ok(self
588            .read_project(project_id, acting_connection_id)?
589            .connection_ids())
590    }
591
592    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
593        Ok(self
594            .channels
595            .get(&channel_id)
596            .ok_or_else(|| anyhow!("no such channel"))?
597            .connection_ids())
598    }
599
600    pub fn project(&self, project_id: u64) -> Result<&Project> {
601        self.projects
602            .get(&project_id)
603            .ok_or_else(|| anyhow!("no such project"))
604    }
605
606    pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
607        let project = self
608            .projects
609            .get(&project_id)
610            .ok_or_else(|| anyhow!("no such project"))?;
611        if project.host_connection_id == connection_id
612            || project.guests.contains_key(&connection_id)
613        {
614            Ok(project)
615        } else {
616            Err(anyhow!("no such project"))?
617        }
618    }
619
620    fn write_project(
621        &mut self,
622        project_id: u64,
623        connection_id: ConnectionId,
624    ) -> Result<&mut Project> {
625        let project = self
626            .projects
627            .get_mut(&project_id)
628            .ok_or_else(|| anyhow!("no such project"))?;
629        if project.host_connection_id == connection_id
630            || project.guests.contains_key(&connection_id)
631        {
632            Ok(project)
633        } else {
634            Err(anyhow!("no such project"))?
635        }
636    }
637
638    #[cfg(test)]
639    pub fn check_invariants(&self) {
640        for (connection_id, connection) in &self.connections {
641            for project_id in &connection.projects {
642                let project = &self.projects.get(&project_id).unwrap();
643                if project.host_connection_id != *connection_id {
644                    assert!(project.guests.contains_key(connection_id));
645                }
646
647                for (worktree_id, worktree) in project.worktrees.iter() {
648                    let mut paths = HashMap::default();
649                    for entry in worktree.entries.values() {
650                        let prev_entry = paths.insert(&entry.path, entry);
651                        assert_eq!(
652                            prev_entry,
653                            None,
654                            "worktree {:?}, duplicate path for entries {:?} and {:?}",
655                            worktree_id,
656                            prev_entry.unwrap(),
657                            entry
658                        );
659                    }
660                }
661            }
662            for channel_id in &connection.channels {
663                let channel = self.channels.get(channel_id).unwrap();
664                assert!(channel.connection_ids.contains(connection_id));
665            }
666            assert!(self
667                .connections_by_user_id
668                .get(&connection.user_id)
669                .unwrap()
670                .contains(connection_id));
671        }
672
673        for (user_id, connection_ids) in &self.connections_by_user_id {
674            for connection_id in connection_ids {
675                assert_eq!(
676                    self.connections.get(connection_id).unwrap().user_id,
677                    *user_id
678                );
679            }
680        }
681
682        for (project_id, project) in &self.projects {
683            let host_connection = self.connections.get(&project.host_connection_id).unwrap();
684            assert!(host_connection.projects.contains(project_id));
685
686            for guest_connection_id in project.guests.keys() {
687                let guest_connection = self.connections.get(guest_connection_id).unwrap();
688                assert!(guest_connection.projects.contains(project_id));
689            }
690            assert_eq!(project.active_replica_ids.len(), project.guests.len(),);
691            assert_eq!(
692                project.active_replica_ids,
693                project
694                    .guests
695                    .values()
696                    .map(|(replica_id, _)| *replica_id)
697                    .collect::<HashSet<_>>(),
698            );
699        }
700
701        for (channel_id, channel) in &self.channels {
702            for connection_id in &channel.connection_ids {
703                let connection = self.connections.get(connection_id).unwrap();
704                assert!(connection.channels.contains(channel_id));
705            }
706        }
707    }
708}
709
710impl Project {
711    pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
712        self.guests.keys().copied().collect()
713    }
714
715    pub fn connection_ids(&self) -> Vec<ConnectionId> {
716        self.guests
717            .keys()
718            .copied()
719            .chain(Some(self.host_connection_id))
720            .collect()
721    }
722}
723
724impl Channel {
725    fn connection_ids(&self) -> Vec<ConnectionId> {
726        self.connection_ids.iter().copied().collect()
727    }
728}