1use crate::db::{self, ChannelId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId, Receipt};
5use std::{collections::hash_map, mem, path::PathBuf};
6use tracing::instrument;
7
8#[derive(Default)]
9pub struct Store {
10 connections: HashMap<ConnectionId, ConnectionState>,
11 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
12 projects: HashMap<u64, Project>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 requested_projects: HashSet<u64>,
21 channels: HashSet<ChannelId>,
22}
23
24pub struct Project {
25 pub host_connection_id: ConnectionId,
26 pub host_user_id: UserId,
27 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
28 pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
29 pub active_replica_ids: HashSet<ReplicaId>,
30 pub worktrees: HashMap<u64, Worktree>,
31 pub language_servers: Vec<proto::LanguageServer>,
32}
33
34#[derive(Default)]
35pub struct Worktree {
36 pub root_name: String,
37 pub visible: bool,
38 pub entries: HashMap<u64, proto::Entry>,
39 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
40 pub scan_id: u64,
41}
42
43#[derive(Default)]
44pub struct Channel {
45 pub connection_ids: HashSet<ConnectionId>,
46}
47
/// Identifier of a project replica. Guests are handed ids starting at 1
/// (presumably 0 is the host's — confirm against the client).
pub type ReplicaId = u16;
49
50#[derive(Default)]
51pub struct RemovedConnectionState {
52 pub user_id: UserId,
53 pub hosted_projects: HashMap<u64, Project>,
54 pub guest_project_ids: HashSet<u64>,
55 pub contact_ids: HashSet<UserId>,
56}
57
58pub struct LeftProject {
59 pub host_user_id: UserId,
60 pub host_connection_id: ConnectionId,
61 pub connection_ids: Vec<ConnectionId>,
62 pub remove_collaborator: bool,
63 pub cancel_request: Option<UserId>,
64 pub unshare: bool,
65}
66
/// Coarse counters reported by `Store::metrics`.
#[derive(Clone, Copy)]
pub struct Metrics {
    /// Number of live connections.
    pub connections: usize,
    /// Number of projects currently registered in the store.
    pub registered_projects: usize,
    /// Number of registered projects that have at least one guest.
    pub shared_projects: usize,
}
73
74impl Store {
75 pub fn metrics(&self) -> Metrics {
76 let connections = self.connections.len();
77 let mut registered_projects = 0;
78 let mut shared_projects = 0;
79 for project in self.projects.values() {
80 registered_projects += 1;
81 if !project.guests.is_empty() {
82 shared_projects += 1;
83 }
84 }
85
86 Metrics {
87 connections,
88 registered_projects,
89 shared_projects,
90 }
91 }
92
93 #[instrument(skip(self))]
94 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
95 self.connections.insert(
96 connection_id,
97 ConnectionState {
98 user_id,
99 projects: Default::default(),
100 requested_projects: Default::default(),
101 channels: Default::default(),
102 },
103 );
104 self.connections_by_user_id
105 .entry(user_id)
106 .or_default()
107 .insert(connection_id);
108 }
109
110 #[instrument(skip(self))]
111 pub fn remove_connection(
112 &mut self,
113 connection_id: ConnectionId,
114 ) -> Result<RemovedConnectionState> {
115 let connection = self
116 .connections
117 .get_mut(&connection_id)
118 .ok_or_else(|| anyhow!("no such connection"))?;
119
120 let user_id = connection.user_id;
121 let connection_projects = mem::take(&mut connection.projects);
122 let connection_channels = mem::take(&mut connection.channels);
123
124 let mut result = RemovedConnectionState::default();
125 result.user_id = user_id;
126
127 // Leave all channels.
128 for channel_id in connection_channels {
129 self.leave_channel(connection_id, channel_id);
130 }
131
132 // Unregister and leave all projects.
133 for project_id in connection_projects {
134 if let Ok(project) = self.unregister_project(project_id, connection_id) {
135 result.hosted_projects.insert(project_id, project);
136 } else if self.leave_project(connection_id, project_id).is_ok() {
137 result.guest_project_ids.insert(project_id);
138 }
139 }
140
141 let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap();
142 user_connections.remove(&connection_id);
143 if user_connections.is_empty() {
144 self.connections_by_user_id.remove(&user_id);
145 }
146
147 self.connections.remove(&connection_id).unwrap();
148
149 Ok(result)
150 }
151
152 #[cfg(test)]
153 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
154 self.channels.get(&id)
155 }
156
157 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
158 if let Some(connection) = self.connections.get_mut(&connection_id) {
159 connection.channels.insert(channel_id);
160 self.channels
161 .entry(channel_id)
162 .or_default()
163 .connection_ids
164 .insert(connection_id);
165 }
166 }
167
168 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
169 if let Some(connection) = self.connections.get_mut(&connection_id) {
170 connection.channels.remove(&channel_id);
171 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
172 entry.get_mut().connection_ids.remove(&connection_id);
173 if entry.get_mut().connection_ids.is_empty() {
174 entry.remove();
175 }
176 }
177 }
178 }
179
180 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
181 Ok(self
182 .connections
183 .get(&connection_id)
184 .ok_or_else(|| anyhow!("unknown connection"))?
185 .user_id)
186 }
187
188 pub fn connection_ids_for_user<'a>(
189 &'a self,
190 user_id: UserId,
191 ) -> impl 'a + Iterator<Item = ConnectionId> {
192 self.connections_by_user_id
193 .get(&user_id)
194 .into_iter()
195 .flatten()
196 .copied()
197 }
198
199 pub fn is_user_online(&self, user_id: UserId) -> bool {
200 !self
201 .connections_by_user_id
202 .get(&user_id)
203 .unwrap_or(&Default::default())
204 .is_empty()
205 }
206
207 pub fn build_initial_contacts_update(
208 &self,
209 contacts: Vec<db::Contact>,
210 ) -> proto::UpdateContacts {
211 let mut update = proto::UpdateContacts::default();
212
213 for contact in contacts {
214 match contact {
215 db::Contact::Accepted {
216 user_id,
217 should_notify,
218 } => {
219 update
220 .contacts
221 .push(self.contact_for_user(user_id, should_notify));
222 }
223 db::Contact::Outgoing { user_id } => {
224 update.outgoing_requests.push(user_id.to_proto())
225 }
226 db::Contact::Incoming {
227 user_id,
228 should_notify,
229 } => update
230 .incoming_requests
231 .push(proto::IncomingContactRequest {
232 requester_id: user_id.to_proto(),
233 should_notify,
234 }),
235 }
236 }
237
238 update
239 }
240
241 pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
242 proto::Contact {
243 user_id: user_id.to_proto(),
244 projects: self.project_metadata_for_user(user_id),
245 online: self.is_user_online(user_id),
246 should_notify,
247 }
248 }
249
250 pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
251 let connection_ids = self.connections_by_user_id.get(&user_id);
252 let project_ids = connection_ids.iter().flat_map(|connection_ids| {
253 connection_ids
254 .iter()
255 .filter_map(|connection_id| self.connections.get(connection_id))
256 .flat_map(|connection| connection.projects.iter().copied())
257 });
258
259 let mut metadata = Vec::new();
260 for project_id in project_ids {
261 if let Some(project) = self.projects.get(&project_id) {
262 if project.host_user_id == user_id {
263 metadata.push(proto::ProjectMetadata {
264 id: project_id,
265 worktree_root_names: project
266 .worktrees
267 .values()
268 .map(|worktree| worktree.root_name.clone())
269 .collect(),
270 guests: project
271 .guests
272 .values()
273 .map(|(_, user_id)| user_id.to_proto())
274 .collect(),
275 });
276 }
277 }
278 }
279
280 metadata
281 }
282
283 pub fn register_project(
284 &mut self,
285 host_connection_id: ConnectionId,
286 host_user_id: UserId,
287 ) -> u64 {
288 let project_id = self.next_project_id;
289 self.projects.insert(
290 project_id,
291 Project {
292 host_connection_id,
293 host_user_id,
294 guests: Default::default(),
295 join_requests: Default::default(),
296 active_replica_ids: Default::default(),
297 worktrees: Default::default(),
298 language_servers: Default::default(),
299 },
300 );
301 if let Some(connection) = self.connections.get_mut(&host_connection_id) {
302 connection.projects.insert(project_id);
303 }
304 self.next_project_id += 1;
305 project_id
306 }
307
308 pub fn register_worktree(
309 &mut self,
310 project_id: u64,
311 worktree_id: u64,
312 connection_id: ConnectionId,
313 worktree: Worktree,
314 ) -> Result<()> {
315 let project = self
316 .projects
317 .get_mut(&project_id)
318 .ok_or_else(|| anyhow!("no such project"))?;
319 if project.host_connection_id == connection_id {
320 project.worktrees.insert(worktree_id, worktree);
321 Ok(())
322 } else {
323 Err(anyhow!("no such project"))?
324 }
325 }
326
327 pub fn unregister_project(
328 &mut self,
329 project_id: u64,
330 connection_id: ConnectionId,
331 ) -> Result<Project> {
332 match self.projects.entry(project_id) {
333 hash_map::Entry::Occupied(e) => {
334 if e.get().host_connection_id == connection_id {
335 let project = e.remove();
336
337 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
338 host_connection.projects.remove(&project_id);
339 }
340
341 for guest_connection in project.guests.keys() {
342 if let Some(connection) = self.connections.get_mut(&guest_connection) {
343 connection.projects.remove(&project_id);
344 }
345 }
346
347 for requester_user_id in project.join_requests.keys() {
348 if let Some(requester_connection_ids) =
349 self.connections_by_user_id.get_mut(&requester_user_id)
350 {
351 for requester_connection_id in requester_connection_ids.iter() {
352 if let Some(requester_connection) =
353 self.connections.get_mut(requester_connection_id)
354 {
355 requester_connection.requested_projects.remove(&project_id);
356 }
357 }
358 }
359 }
360
361 Ok(project)
362 } else {
363 Err(anyhow!("no such project"))?
364 }
365 }
366 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
367 }
368 }
369
370 pub fn unregister_worktree(
371 &mut self,
372 project_id: u64,
373 worktree_id: u64,
374 acting_connection_id: ConnectionId,
375 ) -> Result<(Worktree, Vec<ConnectionId>)> {
376 let project = self
377 .projects
378 .get_mut(&project_id)
379 .ok_or_else(|| anyhow!("no such project"))?;
380 if project.host_connection_id != acting_connection_id {
381 Err(anyhow!("not your worktree"))?;
382 }
383
384 let worktree = project
385 .worktrees
386 .remove(&worktree_id)
387 .ok_or_else(|| anyhow!("no such worktree"))?;
388 Ok((worktree, project.guest_connection_ids()))
389 }
390
391 pub fn update_diagnostic_summary(
392 &mut self,
393 project_id: u64,
394 worktree_id: u64,
395 connection_id: ConnectionId,
396 summary: proto::DiagnosticSummary,
397 ) -> Result<Vec<ConnectionId>> {
398 let project = self
399 .projects
400 .get_mut(&project_id)
401 .ok_or_else(|| anyhow!("no such project"))?;
402 if project.host_connection_id == connection_id {
403 let worktree = project
404 .worktrees
405 .get_mut(&worktree_id)
406 .ok_or_else(|| anyhow!("no such worktree"))?;
407 worktree
408 .diagnostic_summaries
409 .insert(summary.path.clone().into(), summary);
410 return Ok(project.connection_ids());
411 }
412
413 Err(anyhow!("no such worktree"))?
414 }
415
416 pub fn start_language_server(
417 &mut self,
418 project_id: u64,
419 connection_id: ConnectionId,
420 language_server: proto::LanguageServer,
421 ) -> Result<Vec<ConnectionId>> {
422 let project = self
423 .projects
424 .get_mut(&project_id)
425 .ok_or_else(|| anyhow!("no such project"))?;
426 if project.host_connection_id == connection_id {
427 project.language_servers.push(language_server);
428 return Ok(project.connection_ids());
429 }
430
431 Err(anyhow!("no such project"))?
432 }
433
434 pub fn request_join_project(
435 &mut self,
436 requester_id: UserId,
437 project_id: u64,
438 receipt: Receipt<proto::JoinProject>,
439 ) -> Result<()> {
440 let connection = self
441 .connections
442 .get_mut(&receipt.sender_id)
443 .ok_or_else(|| anyhow!("no such connection"))?;
444 let project = self
445 .projects
446 .get_mut(&project_id)
447 .ok_or_else(|| anyhow!("no such project"))?;
448 connection.requested_projects.insert(project_id);
449 project
450 .join_requests
451 .entry(requester_id)
452 .or_default()
453 .push(receipt);
454 Ok(())
455 }
456
457 pub fn deny_join_project_request(
458 &mut self,
459 responder_connection_id: ConnectionId,
460 requester_id: UserId,
461 project_id: u64,
462 ) -> Option<Vec<Receipt<proto::JoinProject>>> {
463 let project = self.projects.get_mut(&project_id)?;
464 if responder_connection_id != project.host_connection_id {
465 return None;
466 }
467
468 let receipts = project.join_requests.remove(&requester_id)?;
469 for receipt in &receipts {
470 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
471 requester_connection.requested_projects.remove(&project_id);
472 }
473 Some(receipts)
474 }
475
476 pub fn accept_join_project_request(
477 &mut self,
478 responder_connection_id: ConnectionId,
479 requester_id: UserId,
480 project_id: u64,
481 ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
482 let project = self.projects.get_mut(&project_id)?;
483 if responder_connection_id != project.host_connection_id {
484 return None;
485 }
486
487 let receipts = project.join_requests.remove(&requester_id)?;
488 let mut receipts_with_replica_ids = Vec::new();
489 for receipt in receipts {
490 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
491 requester_connection.requested_projects.remove(&project_id);
492 requester_connection.projects.insert(project_id);
493 let mut replica_id = 1;
494 while project.active_replica_ids.contains(&replica_id) {
495 replica_id += 1;
496 }
497 project.active_replica_ids.insert(replica_id);
498 project
499 .guests
500 .insert(receipt.sender_id, (replica_id, requester_id));
501 receipts_with_replica_ids.push((receipt, replica_id));
502 }
503
504 Some((receipts_with_replica_ids, project))
505 }
506
507 pub fn leave_project(
508 &mut self,
509 connection_id: ConnectionId,
510 project_id: u64,
511 ) -> Result<LeftProject> {
512 let user_id = self.user_id_for_connection(connection_id)?;
513 let project = self
514 .projects
515 .get_mut(&project_id)
516 .ok_or_else(|| anyhow!("no such project"))?;
517
518 // If the connection leaving the project is a collaborator, remove it.
519 let remove_collaborator =
520 if let Some((replica_id, _)) = project.guests.remove(&connection_id) {
521 project.active_replica_ids.remove(&replica_id);
522 true
523 } else {
524 false
525 };
526
527 // If the connection leaving the project has a pending request, remove it.
528 // If that user has no other pending requests on other connections, indicate that the request should be cancelled.
529 let mut cancel_request = None;
530 if let Entry::Occupied(mut entry) = project.join_requests.entry(user_id) {
531 entry
532 .get_mut()
533 .retain(|receipt| receipt.sender_id != connection_id);
534 if entry.get().is_empty() {
535 entry.remove();
536 cancel_request = Some(user_id);
537 }
538 }
539
540 if let Some(connection) = self.connections.get_mut(&connection_id) {
541 connection.projects.remove(&project_id);
542 }
543
544 let connection_ids = project.connection_ids();
545 let unshare = connection_ids.len() <= 1 && project.join_requests.is_empty();
546
547 Ok(LeftProject {
548 host_connection_id: project.host_connection_id,
549 host_user_id: project.host_user_id,
550 connection_ids,
551 cancel_request,
552 unshare,
553 remove_collaborator,
554 })
555 }
556
557 pub fn update_worktree(
558 &mut self,
559 connection_id: ConnectionId,
560 project_id: u64,
561 worktree_id: u64,
562 removed_entries: &[u64],
563 updated_entries: &[proto::Entry],
564 scan_id: u64,
565 ) -> Result<Vec<ConnectionId>> {
566 let project = self.write_project(project_id, connection_id)?;
567 let worktree = project
568 .worktrees
569 .get_mut(&worktree_id)
570 .ok_or_else(|| anyhow!("no such worktree"))?;
571 for entry_id in removed_entries {
572 worktree.entries.remove(&entry_id);
573 }
574 for entry in updated_entries {
575 worktree.entries.insert(entry.id, entry.clone());
576 }
577 worktree.scan_id = scan_id;
578 let connection_ids = project.connection_ids();
579 Ok(connection_ids)
580 }
581
582 pub fn project_connection_ids(
583 &self,
584 project_id: u64,
585 acting_connection_id: ConnectionId,
586 ) -> Result<Vec<ConnectionId>> {
587 Ok(self
588 .read_project(project_id, acting_connection_id)?
589 .connection_ids())
590 }
591
592 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
593 Ok(self
594 .channels
595 .get(&channel_id)
596 .ok_or_else(|| anyhow!("no such channel"))?
597 .connection_ids())
598 }
599
600 pub fn project(&self, project_id: u64) -> Result<&Project> {
601 self.projects
602 .get(&project_id)
603 .ok_or_else(|| anyhow!("no such project"))
604 }
605
606 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
607 let project = self
608 .projects
609 .get(&project_id)
610 .ok_or_else(|| anyhow!("no such project"))?;
611 if project.host_connection_id == connection_id
612 || project.guests.contains_key(&connection_id)
613 {
614 Ok(project)
615 } else {
616 Err(anyhow!("no such project"))?
617 }
618 }
619
620 fn write_project(
621 &mut self,
622 project_id: u64,
623 connection_id: ConnectionId,
624 ) -> Result<&mut Project> {
625 let project = self
626 .projects
627 .get_mut(&project_id)
628 .ok_or_else(|| anyhow!("no such project"))?;
629 if project.host_connection_id == connection_id
630 || project.guests.contains_key(&connection_id)
631 {
632 Ok(project)
633 } else {
634 Err(anyhow!("no such project"))?
635 }
636 }
637
638 #[cfg(test)]
639 pub fn check_invariants(&self) {
640 for (connection_id, connection) in &self.connections {
641 for project_id in &connection.projects {
642 let project = &self.projects.get(&project_id).unwrap();
643 if project.host_connection_id != *connection_id {
644 assert!(project.guests.contains_key(connection_id));
645 }
646
647 for (worktree_id, worktree) in project.worktrees.iter() {
648 let mut paths = HashMap::default();
649 for entry in worktree.entries.values() {
650 let prev_entry = paths.insert(&entry.path, entry);
651 assert_eq!(
652 prev_entry,
653 None,
654 "worktree {:?}, duplicate path for entries {:?} and {:?}",
655 worktree_id,
656 prev_entry.unwrap(),
657 entry
658 );
659 }
660 }
661 }
662 for channel_id in &connection.channels {
663 let channel = self.channels.get(channel_id).unwrap();
664 assert!(channel.connection_ids.contains(connection_id));
665 }
666 assert!(self
667 .connections_by_user_id
668 .get(&connection.user_id)
669 .unwrap()
670 .contains(connection_id));
671 }
672
673 for (user_id, connection_ids) in &self.connections_by_user_id {
674 for connection_id in connection_ids {
675 assert_eq!(
676 self.connections.get(connection_id).unwrap().user_id,
677 *user_id
678 );
679 }
680 }
681
682 for (project_id, project) in &self.projects {
683 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
684 assert!(host_connection.projects.contains(project_id));
685
686 for guest_connection_id in project.guests.keys() {
687 let guest_connection = self.connections.get(guest_connection_id).unwrap();
688 assert!(guest_connection.projects.contains(project_id));
689 }
690 assert_eq!(project.active_replica_ids.len(), project.guests.len(),);
691 assert_eq!(
692 project.active_replica_ids,
693 project
694 .guests
695 .values()
696 .map(|(replica_id, _)| *replica_id)
697 .collect::<HashSet<_>>(),
698 );
699 }
700
701 for (channel_id, channel) in &self.channels {
702 for connection_id in &channel.connection_ids {
703 let connection = self.connections.get(connection_id).unwrap();
704 assert!(connection.channels.contains(channel_id));
705 }
706 }
707 }
708}
709
710impl Project {
711 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
712 self.guests.keys().copied().collect()
713 }
714
715 pub fn connection_ids(&self) -> Vec<ConnectionId> {
716 self.guests
717 .keys()
718 .copied()
719 .chain(Some(self.host_connection_id))
720 .collect()
721 }
722}
723
724impl Channel {
725 fn connection_ids(&self) -> Vec<ConnectionId> {
726 self.connection_ids.iter().copied().collect()
727 }
728}