1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
/// In-memory registry of connections, projects and channels, together with
/// the secondary indexes needed to look them up by user.
#[derive(Default)]
pub struct Store {
    /// State of every known connection, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Ids of the connections that belong to each user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Registered projects, keyed by project id.
    projects: HashMap<u64, Project>,
    /// For each user, the ids of the projects in which at least one worktree
    /// lists that user as authorized.
    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channels keyed by id; each records the connections that joined it.
    channels: HashMap<ChannelId, Channel>,
    /// Id that `register_project` assigns to the next registered project.
    next_project_id: u64,
}
16
/// Per-connection bookkeeping kept by the `Store`.
struct ConnectionState {
    /// User on whose behalf the connection was added.
    user_id: UserId,
    /// Ids of the projects this connection hosts or has joined as a guest.
    projects: HashSet<u64>,
    /// Ids of the channels this connection has joined.
    channels: HashSet<ChannelId>,
}
22
/// A project registered by a host connection.
pub struct Project {
    /// Connection that registered the project; only this connection may
    /// register worktrees, share, unshare or unregister it.
    pub host_connection_id: ConnectionId,
    /// User id supplied when the project was registered.
    pub host_user_id: UserId,
    /// Sharing state; `Some` between `share_project` and `unshare_project`.
    pub share: Option<ProjectShare>,
    /// Worktrees registered on the project, keyed by worktree id.
    pub worktrees: HashMap<u64, Worktree>,
    /// Language servers reported through `start_language_server`.
    pub language_servers: Vec<proto::LanguageServer>,
}
30
/// Metadata the host supplies when registering a worktree.
pub struct Worktree {
    /// Users for whom the owning project is visible and joinable.
    pub authorized_user_ids: Vec<UserId>,
    /// Name reported in the project metadata returned by `contacts_for_user`.
    pub root_name: String,
    /// Whether `root_name` is included in that project metadata.
    pub visible: bool,
}
36
/// State that exists only while a project is shared.
#[derive(Default)]
pub struct ProjectShare {
    /// Guests of the project: connection id mapped to the replica id assigned
    /// on join and the user id supplied on join.
    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
    /// Replica ids currently assigned to guests.
    pub active_replica_ids: HashSet<ReplicaId>,
    /// Shared state of each worktree, keyed by worktree id.
    pub worktrees: HashMap<u64, WorktreeShare>,
}
43
/// Shared contents of a single worktree.
#[derive(Default)]
pub struct WorktreeShare {
    /// Entries of the worktree, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
    /// Latest diagnostic summary recorded for each path.
    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
}
49
/// Membership of a single channel.
#[derive(Default)]
pub struct Channel {
    /// Connections that have joined the channel.
    pub connection_ids: HashSet<ConnectionId>,
}
54
/// Identifier assigned to each guest of a shared project; `join_project`
/// hands out the smallest unused value starting at 1.
pub type ReplicaId = u16;
56
/// Summary of what `Store::remove_connection` tore down.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Projects the removed connection hosted, now unregistered, keyed by
    /// project id.
    pub hosted_projects: HashMap<u64, Project>,
    /// For each project the removed connection left as a guest, the
    /// connection ids that remain on that project.
    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Authorized users of every project affected by the removal.
    pub contact_ids: HashSet<UserId>,
}
63
/// Result of `Store::join_project`.
pub struct JoinedProject<'a> {
    /// Replica id assigned to the joining connection.
    pub replica_id: ReplicaId,
    /// The project that was joined.
    pub project: &'a Project,
}
68
/// Result of `Store::unshare_project`.
pub struct UnsharedProject {
    /// Host and guest connection ids, captured before the share was removed.
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized on any worktree of the project.
    pub authorized_user_ids: Vec<UserId>,
}
73
/// Result of `Store::leave_project`.
pub struct LeftProject {
    /// Connection ids still on the project after the guest left.
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized on any worktree of the project.
    pub authorized_user_ids: Vec<UserId>,
}
78
79impl Store {
80 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
81 self.connections.insert(
82 connection_id,
83 ConnectionState {
84 user_id,
85 projects: Default::default(),
86 channels: Default::default(),
87 },
88 );
89 self.connections_by_user_id
90 .entry(user_id)
91 .or_default()
92 .insert(connection_id);
93 }
94
95 pub fn remove_connection(
96 &mut self,
97 connection_id: ConnectionId,
98 ) -> tide::Result<RemovedConnectionState> {
99 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
100 connection
101 } else {
102 return Err(anyhow!("no such connection"))?;
103 };
104
105 for channel_id in &connection.channels {
106 if let Some(channel) = self.channels.get_mut(&channel_id) {
107 channel.connection_ids.remove(&connection_id);
108 }
109 }
110
111 let user_connections = self
112 .connections_by_user_id
113 .get_mut(&connection.user_id)
114 .unwrap();
115 user_connections.remove(&connection_id);
116 if user_connections.is_empty() {
117 self.connections_by_user_id.remove(&connection.user_id);
118 }
119
120 let mut result = RemovedConnectionState::default();
121 for project_id in connection.projects.clone() {
122 if let Ok(project) = self.unregister_project(project_id, connection_id) {
123 result.contact_ids.extend(project.authorized_user_ids());
124 result.hosted_projects.insert(project_id, project);
125 } else if let Ok(project) = self.leave_project(connection_id, project_id) {
126 result
127 .guest_project_ids
128 .insert(project_id, project.connection_ids);
129 result.contact_ids.extend(project.authorized_user_ids);
130 }
131 }
132
133 #[cfg(test)]
134 self.check_invariants();
135
136 Ok(result)
137 }
138
139 #[cfg(test)]
140 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
141 self.channels.get(&id)
142 }
143
144 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145 if let Some(connection) = self.connections.get_mut(&connection_id) {
146 connection.channels.insert(channel_id);
147 self.channels
148 .entry(channel_id)
149 .or_default()
150 .connection_ids
151 .insert(connection_id);
152 }
153 }
154
155 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
156 if let Some(connection) = self.connections.get_mut(&connection_id) {
157 connection.channels.remove(&channel_id);
158 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
159 entry.get_mut().connection_ids.remove(&connection_id);
160 if entry.get_mut().connection_ids.is_empty() {
161 entry.remove();
162 }
163 }
164 }
165 }
166
167 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
168 Ok(self
169 .connections
170 .get(&connection_id)
171 .ok_or_else(|| anyhow!("unknown connection"))?
172 .user_id)
173 }
174
175 pub fn connection_ids_for_user<'a>(
176 &'a self,
177 user_id: UserId,
178 ) -> impl 'a + Iterator<Item = ConnectionId> {
179 self.connections_by_user_id
180 .get(&user_id)
181 .into_iter()
182 .flatten()
183 .copied()
184 }
185
186 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
187 let mut contacts = HashMap::default();
188 for project_id in self
189 .visible_projects_by_user_id
190 .get(&user_id)
191 .unwrap_or(&HashSet::default())
192 {
193 let project = &self.projects[project_id];
194
195 let mut guests = HashSet::default();
196 if let Ok(share) = project.share() {
197 for guest_connection_id in share.guests.keys() {
198 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
199 guests.insert(user_id.to_proto());
200 }
201 }
202 }
203
204 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
205 let mut worktree_root_names = project
206 .worktrees
207 .values()
208 .filter(|worktree| worktree.visible)
209 .map(|worktree| worktree.root_name.clone())
210 .collect::<Vec<_>>();
211 worktree_root_names.sort_unstable();
212 contacts
213 .entry(host_user_id)
214 .or_insert_with(|| proto::Contact {
215 user_id: host_user_id.to_proto(),
216 projects: Vec::new(),
217 })
218 .projects
219 .push(proto::ProjectMetadata {
220 id: *project_id,
221 worktree_root_names,
222 is_shared: project.share.is_some(),
223 guests: guests.into_iter().collect(),
224 });
225 }
226 }
227
228 contacts.into_values().collect()
229 }
230
231 pub fn register_project(
232 &mut self,
233 host_connection_id: ConnectionId,
234 host_user_id: UserId,
235 ) -> u64 {
236 let project_id = self.next_project_id;
237 self.projects.insert(
238 project_id,
239 Project {
240 host_connection_id,
241 host_user_id,
242 share: None,
243 worktrees: Default::default(),
244 language_servers: Default::default(),
245 },
246 );
247 if let Some(connection) = self.connections.get_mut(&host_connection_id) {
248 connection.projects.insert(project_id);
249 }
250 self.next_project_id += 1;
251 project_id
252 }
253
254 pub fn register_worktree(
255 &mut self,
256 project_id: u64,
257 worktree_id: u64,
258 connection_id: ConnectionId,
259 worktree: Worktree,
260 ) -> tide::Result<()> {
261 let project = self
262 .projects
263 .get_mut(&project_id)
264 .ok_or_else(|| anyhow!("no such project"))?;
265 if project.host_connection_id == connection_id {
266 for authorized_user_id in &worktree.authorized_user_ids {
267 self.visible_projects_by_user_id
268 .entry(*authorized_user_id)
269 .or_default()
270 .insert(project_id);
271 }
272
273 project.worktrees.insert(worktree_id, worktree);
274 if let Ok(share) = project.share_mut() {
275 share.worktrees.insert(worktree_id, Default::default());
276 }
277
278 #[cfg(test)]
279 self.check_invariants();
280 Ok(())
281 } else {
282 Err(anyhow!("no such project"))?
283 }
284 }
285
286 pub fn unregister_project(
287 &mut self,
288 project_id: u64,
289 connection_id: ConnectionId,
290 ) -> tide::Result<Project> {
291 match self.projects.entry(project_id) {
292 hash_map::Entry::Occupied(e) => {
293 if e.get().host_connection_id == connection_id {
294 for user_id in e.get().authorized_user_ids() {
295 if let hash_map::Entry::Occupied(mut projects) =
296 self.visible_projects_by_user_id.entry(user_id)
297 {
298 projects.get_mut().remove(&project_id);
299 }
300 }
301
302 let project = e.remove();
303
304 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
305 host_connection.projects.remove(&project_id);
306 }
307
308 if let Some(share) = &project.share {
309 for guest_connection in share.guests.keys() {
310 if let Some(connection) = self.connections.get_mut(&guest_connection) {
311 connection.projects.remove(&project_id);
312 }
313 }
314 }
315
316 #[cfg(test)]
317 self.check_invariants();
318 Ok(project)
319 } else {
320 Err(anyhow!("no such project"))?
321 }
322 }
323 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
324 }
325 }
326
327 pub fn unregister_worktree(
328 &mut self,
329 project_id: u64,
330 worktree_id: u64,
331 acting_connection_id: ConnectionId,
332 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
333 let project = self
334 .projects
335 .get_mut(&project_id)
336 .ok_or_else(|| anyhow!("no such project"))?;
337 if project.host_connection_id != acting_connection_id {
338 Err(anyhow!("not your worktree"))?;
339 }
340
341 let worktree = project
342 .worktrees
343 .remove(&worktree_id)
344 .ok_or_else(|| anyhow!("no such worktree"))?;
345
346 let mut guest_connection_ids = Vec::new();
347 if let Ok(share) = project.share_mut() {
348 guest_connection_ids.extend(share.guests.keys());
349 share.worktrees.remove(&worktree_id);
350 }
351
352 for authorized_user_id in &worktree.authorized_user_ids {
353 if let Some(visible_projects) =
354 self.visible_projects_by_user_id.get_mut(authorized_user_id)
355 {
356 if !project.has_authorized_user_id(*authorized_user_id) {
357 visible_projects.remove(&project_id);
358 }
359 }
360 }
361
362 #[cfg(test)]
363 self.check_invariants();
364
365 Ok((worktree, guest_connection_ids))
366 }
367
368 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
369 if let Some(project) = self.projects.get_mut(&project_id) {
370 if project.host_connection_id == connection_id {
371 let mut share = ProjectShare::default();
372 for worktree_id in project.worktrees.keys() {
373 share.worktrees.insert(*worktree_id, Default::default());
374 }
375 project.share = Some(share);
376 return true;
377 }
378 }
379 false
380 }
381
382 pub fn unshare_project(
383 &mut self,
384 project_id: u64,
385 acting_connection_id: ConnectionId,
386 ) -> tide::Result<UnsharedProject> {
387 let project = if let Some(project) = self.projects.get_mut(&project_id) {
388 project
389 } else {
390 return Err(anyhow!("no such project"))?;
391 };
392
393 if project.host_connection_id != acting_connection_id {
394 return Err(anyhow!("not your project"))?;
395 }
396
397 let connection_ids = project.connection_ids();
398 let authorized_user_ids = project.authorized_user_ids();
399 if let Some(share) = project.share.take() {
400 for connection_id in share.guests.into_keys() {
401 if let Some(connection) = self.connections.get_mut(&connection_id) {
402 connection.projects.remove(&project_id);
403 }
404 }
405
406 #[cfg(test)]
407 self.check_invariants();
408
409 Ok(UnsharedProject {
410 connection_ids,
411 authorized_user_ids,
412 })
413 } else {
414 Err(anyhow!("project is not shared"))?
415 }
416 }
417
418 pub fn update_diagnostic_summary(
419 &mut self,
420 project_id: u64,
421 worktree_id: u64,
422 connection_id: ConnectionId,
423 summary: proto::DiagnosticSummary,
424 ) -> tide::Result<Vec<ConnectionId>> {
425 let project = self
426 .projects
427 .get_mut(&project_id)
428 .ok_or_else(|| anyhow!("no such project"))?;
429 if project.host_connection_id == connection_id {
430 let worktree = project
431 .share_mut()?
432 .worktrees
433 .get_mut(&worktree_id)
434 .ok_or_else(|| anyhow!("no such worktree"))?;
435 worktree
436 .diagnostic_summaries
437 .insert(summary.path.clone().into(), summary);
438 return Ok(project.connection_ids());
439 }
440
441 Err(anyhow!("no such worktree"))?
442 }
443
444 pub fn start_language_server(
445 &mut self,
446 project_id: u64,
447 connection_id: ConnectionId,
448 language_server: proto::LanguageServer,
449 ) -> tide::Result<Vec<ConnectionId>> {
450 let project = self
451 .projects
452 .get_mut(&project_id)
453 .ok_or_else(|| anyhow!("no such project"))?;
454 if project.host_connection_id == connection_id {
455 project.language_servers.push(language_server);
456 return Ok(project.connection_ids());
457 }
458
459 Err(anyhow!("no such project"))?
460 }
461
462 pub fn join_project(
463 &mut self,
464 connection_id: ConnectionId,
465 user_id: UserId,
466 project_id: u64,
467 ) -> tide::Result<JoinedProject> {
468 let connection = self
469 .connections
470 .get_mut(&connection_id)
471 .ok_or_else(|| anyhow!("no such connection"))?;
472 let project = self
473 .projects
474 .get_mut(&project_id)
475 .and_then(|project| {
476 if project.has_authorized_user_id(user_id) {
477 Some(project)
478 } else {
479 None
480 }
481 })
482 .ok_or_else(|| anyhow!("no such project"))?;
483
484 let share = project.share_mut()?;
485 connection.projects.insert(project_id);
486
487 let mut replica_id = 1;
488 while share.active_replica_ids.contains(&replica_id) {
489 replica_id += 1;
490 }
491 share.active_replica_ids.insert(replica_id);
492 share.guests.insert(connection_id, (replica_id, user_id));
493
494 #[cfg(test)]
495 self.check_invariants();
496
497 Ok(JoinedProject {
498 replica_id,
499 project: &self.projects[&project_id],
500 })
501 }
502
503 pub fn leave_project(
504 &mut self,
505 connection_id: ConnectionId,
506 project_id: u64,
507 ) -> tide::Result<LeftProject> {
508 let project = self
509 .projects
510 .get_mut(&project_id)
511 .ok_or_else(|| anyhow!("no such project"))?;
512 let share = project
513 .share
514 .as_mut()
515 .ok_or_else(|| anyhow!("project is not shared"))?;
516 let (replica_id, _) = share
517 .guests
518 .remove(&connection_id)
519 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
520 share.active_replica_ids.remove(&replica_id);
521
522 if let Some(connection) = self.connections.get_mut(&connection_id) {
523 connection.projects.remove(&project_id);
524 }
525
526 let connection_ids = project.connection_ids();
527 let authorized_user_ids = project.authorized_user_ids();
528
529 #[cfg(test)]
530 self.check_invariants();
531
532 Ok(LeftProject {
533 connection_ids,
534 authorized_user_ids,
535 })
536 }
537
538 pub fn update_worktree(
539 &mut self,
540 connection_id: ConnectionId,
541 project_id: u64,
542 worktree_id: u64,
543 removed_entries: &[u64],
544 updated_entries: &[proto::Entry],
545 ) -> tide::Result<Vec<ConnectionId>> {
546 let project = self.write_project(project_id, connection_id)?;
547 let worktree = project
548 .share_mut()?
549 .worktrees
550 .get_mut(&worktree_id)
551 .ok_or_else(|| anyhow!("no such worktree"))?;
552 for entry_id in removed_entries {
553 worktree.entries.remove(&entry_id);
554 }
555 for entry in updated_entries {
556 worktree.entries.insert(entry.id, entry.clone());
557 }
558 let connection_ids = project.connection_ids();
559
560 #[cfg(test)]
561 self.check_invariants();
562
563 Ok(connection_ids)
564 }
565
566 pub fn project_connection_ids(
567 &self,
568 project_id: u64,
569 acting_connection_id: ConnectionId,
570 ) -> tide::Result<Vec<ConnectionId>> {
571 Ok(self
572 .read_project(project_id, acting_connection_id)?
573 .connection_ids())
574 }
575
576 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
577 Ok(self
578 .channels
579 .get(&channel_id)
580 .ok_or_else(|| anyhow!("no such channel"))?
581 .connection_ids())
582 }
583
584 #[cfg(test)]
585 pub fn project(&self, project_id: u64) -> Option<&Project> {
586 self.projects.get(&project_id)
587 }
588
589 pub fn read_project(
590 &self,
591 project_id: u64,
592 connection_id: ConnectionId,
593 ) -> tide::Result<&Project> {
594 let project = self
595 .projects
596 .get(&project_id)
597 .ok_or_else(|| anyhow!("no such project"))?;
598 if project.host_connection_id == connection_id
599 || project
600 .share
601 .as_ref()
602 .ok_or_else(|| anyhow!("project is not shared"))?
603 .guests
604 .contains_key(&connection_id)
605 {
606 Ok(project)
607 } else {
608 Err(anyhow!("no such project"))?
609 }
610 }
611
612 fn write_project(
613 &mut self,
614 project_id: u64,
615 connection_id: ConnectionId,
616 ) -> tide::Result<&mut Project> {
617 let project = self
618 .projects
619 .get_mut(&project_id)
620 .ok_or_else(|| anyhow!("no such project"))?;
621 if project.host_connection_id == connection_id
622 || project
623 .share
624 .as_ref()
625 .ok_or_else(|| anyhow!("project is not shared"))?
626 .guests
627 .contains_key(&connection_id)
628 {
629 Ok(project)
630 } else {
631 Err(anyhow!("no such project"))?
632 }
633 }
634
635 #[cfg(test)]
636 fn check_invariants(&self) {
637 for (connection_id, connection) in &self.connections {
638 for project_id in &connection.projects {
639 let project = &self.projects.get(&project_id).unwrap();
640 if project.host_connection_id != *connection_id {
641 assert!(project
642 .share
643 .as_ref()
644 .unwrap()
645 .guests
646 .contains_key(connection_id));
647 }
648
649 if let Some(share) = project.share.as_ref() {
650 for (worktree_id, worktree) in share.worktrees.iter() {
651 let mut paths = HashMap::default();
652 for entry in worktree.entries.values() {
653 let prev_entry = paths.insert(&entry.path, entry);
654 assert_eq!(
655 prev_entry,
656 None,
657 "worktree {:?}, duplicate path for entries {:?} and {:?}",
658 worktree_id,
659 prev_entry.unwrap(),
660 entry
661 );
662 }
663 }
664 }
665 }
666 for channel_id in &connection.channels {
667 let channel = self.channels.get(channel_id).unwrap();
668 assert!(channel.connection_ids.contains(connection_id));
669 }
670 assert!(self
671 .connections_by_user_id
672 .get(&connection.user_id)
673 .unwrap()
674 .contains(connection_id));
675 }
676
677 for (user_id, connection_ids) in &self.connections_by_user_id {
678 for connection_id in connection_ids {
679 assert_eq!(
680 self.connections.get(connection_id).unwrap().user_id,
681 *user_id
682 );
683 }
684 }
685
686 for (project_id, project) in &self.projects {
687 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
688 assert!(host_connection.projects.contains(project_id));
689
690 for authorized_user_ids in project.authorized_user_ids() {
691 let visible_project_ids = self
692 .visible_projects_by_user_id
693 .get(&authorized_user_ids)
694 .unwrap();
695 assert!(visible_project_ids.contains(project_id));
696 }
697
698 if let Some(share) = &project.share {
699 for guest_connection_id in share.guests.keys() {
700 let guest_connection = self.connections.get(guest_connection_id).unwrap();
701 assert!(guest_connection.projects.contains(project_id));
702 }
703 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
704 assert_eq!(
705 share.active_replica_ids,
706 share
707 .guests
708 .values()
709 .map(|(replica_id, _)| *replica_id)
710 .collect::<HashSet<_>>(),
711 );
712 }
713 }
714
715 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
716 for project_id in visible_project_ids {
717 let project = self.projects.get(project_id).unwrap();
718 assert!(project.authorized_user_ids().contains(user_id));
719 }
720 }
721
722 for (channel_id, channel) in &self.channels {
723 for connection_id in &channel.connection_ids {
724 let connection = self.connections.get(connection_id).unwrap();
725 assert!(connection.channels.contains(channel_id));
726 }
727 }
728 }
729}
730
731impl Project {
732 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
733 self.worktrees
734 .values()
735 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
736 }
737
738 pub fn authorized_user_ids(&self) -> Vec<UserId> {
739 let mut ids = self
740 .worktrees
741 .values()
742 .flat_map(|worktree| worktree.authorized_user_ids.iter())
743 .copied()
744 .collect::<Vec<_>>();
745 ids.sort_unstable();
746 ids.dedup();
747 ids
748 }
749
750 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
751 if let Some(share) = &self.share {
752 share.guests.keys().copied().collect()
753 } else {
754 Vec::new()
755 }
756 }
757
758 pub fn connection_ids(&self) -> Vec<ConnectionId> {
759 if let Some(share) = &self.share {
760 share
761 .guests
762 .keys()
763 .copied()
764 .chain(Some(self.host_connection_id))
765 .collect()
766 } else {
767 vec![self.host_connection_id]
768 }
769 }
770
771 pub fn share(&self) -> tide::Result<&ProjectShare> {
772 Ok(self
773 .share
774 .as_ref()
775 .ok_or_else(|| anyhow!("worktree is not shared"))?)
776 }
777
778 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
779 Ok(self
780 .share
781 .as_mut()
782 .ok_or_else(|| anyhow!("worktree is not shared"))?)
783 }
784}
785
786impl Channel {
787 fn connection_ids(&self) -> Vec<ConnectionId> {
788 self.connection_ids.iter().copied().collect()
789 }
790}