1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
/// Per-connection bookkeeping kept by `Store`.
struct ConnectionState {
    /// The user this connection was registered for in `Store::add_connection`.
    user_id: UserId,
    /// Ids of projects this connection takes part in, as host or as guest.
    projects: HashSet<u64>,
    /// Ids of channels this connection has joined.
    channels: HashSet<ChannelId>,
}
22
/// A project registered with the server by a host connection.
pub struct Project {
    /// Connection that registered the project. Worktree registration/removal,
    /// sharing and unsharing are only accepted from this connection.
    pub host_connection_id: ConnectionId,
    /// User id passed to `Store::register_project` alongside the host connection.
    pub host_user_id: UserId,
    /// Sharing state: `Some` while the project is shared with guests.
    pub share: Option<ProjectShare>,
    /// The project's worktrees, keyed by worktree id.
    pub worktrees: HashMap<u64, Worktree>,
    /// Language servers reported by the host via `Store::start_language_server`.
    pub language_servers: Vec<proto::LanguageServer>,
}
30
/// A worktree belonging to a `Project`.
pub struct Worktree {
    /// Users who may see the project in their contacts and join it through
    /// this worktree.
    pub authorized_user_ids: Vec<UserId>,
    /// Name of the worktree's root directory, reported in contact metadata.
    pub root_name: String,
    /// Whether `root_name` is included in the project's contact metadata.
    pub visible: bool,
}
36
/// State that exists only while a project is shared.
#[derive(Default)]
pub struct ProjectShare {
    /// Guest connections, each mapped to its (replica id, user id).
    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
    /// Replica ids currently held by guests; kept in sync with `guests`.
    pub active_replica_ids: HashSet<ReplicaId>,
    /// Shared per-worktree state, keyed by worktree id.
    pub worktrees: HashMap<u64, WorktreeShare>,
}
43
/// Contents of a shared worktree as last reported by the host.
#[derive(Default)]
pub struct WorktreeShare {
    /// Worktree entries, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
    /// Latest diagnostic summary per path.
    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
}
49
/// Membership of a single channel.
#[derive(Default)]
pub struct Channel {
    /// Connections that have joined this channel.
    pub connection_ids: HashSet<ConnectionId>,
}
54
/// Identifies a guest's replica within a shared project. `Store::join_project`
/// assigns each new guest the smallest id not already in use, starting at 1.
pub type ReplicaId = u16;
56
/// Summary of what `Store::remove_connection` tore down.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Projects the removed connection hosted; they have been unregistered.
    pub hosted_projects: HashMap<u64, Project>,
    /// For each project the connection had joined as a guest, the connections
    /// still in that project after it left.
    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Authorized users of every project listed above.
    pub contact_ids: HashSet<UserId>,
}
63
/// Result of `Store::join_project`.
pub struct JoinedProject<'a> {
    /// Replica id assigned to the joining guest.
    pub replica_id: ReplicaId,
    /// The project that was joined.
    pub project: &'a Project,
}
68
/// Result of `Store::unshare_project`.
pub struct UnsharedProject {
    /// Connections that were in the project before it was unshared (the former
    /// guests followed by the host).
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized on any of the project's worktrees.
    pub authorized_user_ids: Vec<UserId>,
}
73
/// Result of `Store::leave_project`.
pub struct LeftProject {
    /// Connections still in the project after the guest left (remaining
    /// guests followed by the host).
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized on any of the project's worktrees.
    pub authorized_user_ids: Vec<UserId>,
}
78
79impl Store {
80 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
81 self.connections.insert(
82 connection_id,
83 ConnectionState {
84 user_id,
85 projects: Default::default(),
86 channels: Default::default(),
87 },
88 );
89 self.connections_by_user_id
90 .entry(user_id)
91 .or_default()
92 .insert(connection_id);
93 }
94
95 pub fn remove_connection(
96 &mut self,
97 connection_id: ConnectionId,
98 ) -> tide::Result<RemovedConnectionState> {
99 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
100 connection
101 } else {
102 return Err(anyhow!("no such connection"))?;
103 };
104
105 for channel_id in &connection.channels {
106 if let Some(channel) = self.channels.get_mut(&channel_id) {
107 channel.connection_ids.remove(&connection_id);
108 }
109 }
110
111 let user_connections = self
112 .connections_by_user_id
113 .get_mut(&connection.user_id)
114 .unwrap();
115 user_connections.remove(&connection_id);
116 if user_connections.is_empty() {
117 self.connections_by_user_id.remove(&connection.user_id);
118 }
119
120 let mut result = RemovedConnectionState::default();
121 for project_id in connection.projects.clone() {
122 if let Ok(project) = self.unregister_project(project_id, connection_id) {
123 result.contact_ids.extend(project.authorized_user_ids());
124 result.hosted_projects.insert(project_id, project);
125 } else if let Ok(project) = self.leave_project(connection_id, project_id) {
126 result
127 .guest_project_ids
128 .insert(project_id, project.connection_ids);
129 result.contact_ids.extend(project.authorized_user_ids);
130 }
131 }
132
133 #[cfg(test)]
134 self.check_invariants();
135
136 Ok(result)
137 }
138
139 #[cfg(test)]
140 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
141 self.channels.get(&id)
142 }
143
144 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145 if let Some(connection) = self.connections.get_mut(&connection_id) {
146 connection.channels.insert(channel_id);
147 self.channels
148 .entry(channel_id)
149 .or_default()
150 .connection_ids
151 .insert(connection_id);
152 }
153 }
154
155 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
156 if let Some(connection) = self.connections.get_mut(&connection_id) {
157 connection.channels.remove(&channel_id);
158 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
159 entry.get_mut().connection_ids.remove(&connection_id);
160 if entry.get_mut().connection_ids.is_empty() {
161 entry.remove();
162 }
163 }
164 }
165 }
166
167 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
168 Ok(self
169 .connections
170 .get(&connection_id)
171 .ok_or_else(|| anyhow!("unknown connection"))?
172 .user_id)
173 }
174
175 pub fn connection_ids_for_user<'a>(
176 &'a self,
177 user_id: UserId,
178 ) -> impl 'a + Iterator<Item = ConnectionId> {
179 self.connections_by_user_id
180 .get(&user_id)
181 .into_iter()
182 .flatten()
183 .copied()
184 }
185
186 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
187 let mut contacts = HashMap::default();
188 for project_id in self
189 .visible_projects_by_user_id
190 .get(&user_id)
191 .unwrap_or(&HashSet::default())
192 {
193 let project = &self.projects[project_id];
194
195 let mut guests = HashSet::default();
196 if let Ok(share) = project.share() {
197 for guest_connection_id in share.guests.keys() {
198 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
199 guests.insert(user_id.to_proto());
200 }
201 }
202 }
203
204 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
205 let mut worktree_root_names = project
206 .worktrees
207 .values()
208 .filter(|worktree| worktree.visible)
209 .map(|worktree| worktree.root_name.clone())
210 .collect::<Vec<_>>();
211 worktree_root_names.sort_unstable();
212 contacts
213 .entry(host_user_id)
214 .or_insert_with(|| proto::Contact {
215 user_id: host_user_id.to_proto(),
216 projects: Vec::new(),
217 })
218 .projects
219 .push(proto::ProjectMetadata {
220 id: *project_id,
221 worktree_root_names,
222 is_shared: project.share.is_some(),
223 guests: guests.into_iter().collect(),
224 });
225 }
226 }
227
228 contacts.into_values().collect()
229 }
230
231 pub fn register_project(
232 &mut self,
233 host_connection_id: ConnectionId,
234 host_user_id: UserId,
235 ) -> u64 {
236 let project_id = self.next_project_id;
237 self.projects.insert(
238 project_id,
239 Project {
240 host_connection_id,
241 host_user_id,
242 share: None,
243 worktrees: Default::default(),
244 language_servers: Default::default(),
245 },
246 );
247 self.next_project_id += 1;
248 project_id
249 }
250
251 pub fn register_worktree(
252 &mut self,
253 project_id: u64,
254 worktree_id: u64,
255 connection_id: ConnectionId,
256 worktree: Worktree,
257 ) -> tide::Result<()> {
258 let project = self
259 .projects
260 .get_mut(&project_id)
261 .ok_or_else(|| anyhow!("no such project"))?;
262 if project.host_connection_id == connection_id {
263 for authorized_user_id in &worktree.authorized_user_ids {
264 self.visible_projects_by_user_id
265 .entry(*authorized_user_id)
266 .or_default()
267 .insert(project_id);
268 }
269 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
270 connection.projects.insert(project_id);
271 }
272 project.worktrees.insert(worktree_id, worktree);
273 if let Ok(share) = project.share_mut() {
274 share.worktrees.insert(worktree_id, Default::default());
275 }
276
277 #[cfg(test)]
278 self.check_invariants();
279 Ok(())
280 } else {
281 Err(anyhow!("no such project"))?
282 }
283 }
284
285 pub fn unregister_project(
286 &mut self,
287 project_id: u64,
288 connection_id: ConnectionId,
289 ) -> tide::Result<Project> {
290 match self.projects.entry(project_id) {
291 hash_map::Entry::Occupied(e) => {
292 if e.get().host_connection_id == connection_id {
293 for user_id in e.get().authorized_user_ids() {
294 if let hash_map::Entry::Occupied(mut projects) =
295 self.visible_projects_by_user_id.entry(user_id)
296 {
297 projects.get_mut().remove(&project_id);
298 }
299 }
300
301 let project = e.remove();
302
303 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
304 host_connection.projects.remove(&project_id);
305 }
306
307 if let Some(share) = &project.share {
308 for guest_connection in share.guests.keys() {
309 if let Some(connection) = self.connections.get_mut(&guest_connection) {
310 connection.projects.remove(&project_id);
311 }
312 }
313 }
314
315 #[cfg(test)]
316 self.check_invariants();
317 Ok(project)
318 } else {
319 Err(anyhow!("no such project"))?
320 }
321 }
322 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
323 }
324 }
325
326 pub fn unregister_worktree(
327 &mut self,
328 project_id: u64,
329 worktree_id: u64,
330 acting_connection_id: ConnectionId,
331 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
332 let project = self
333 .projects
334 .get_mut(&project_id)
335 .ok_or_else(|| anyhow!("no such project"))?;
336 if project.host_connection_id != acting_connection_id {
337 Err(anyhow!("not your worktree"))?;
338 }
339
340 let worktree = project
341 .worktrees
342 .remove(&worktree_id)
343 .ok_or_else(|| anyhow!("no such worktree"))?;
344
345 let mut guest_connection_ids = Vec::new();
346 if let Ok(share) = project.share_mut() {
347 guest_connection_ids.extend(share.guests.keys());
348 share.worktrees.remove(&worktree_id);
349 }
350
351 for authorized_user_id in &worktree.authorized_user_ids {
352 if let Some(visible_projects) =
353 self.visible_projects_by_user_id.get_mut(authorized_user_id)
354 {
355 if !project.has_authorized_user_id(*authorized_user_id) {
356 visible_projects.remove(&project_id);
357 }
358 }
359 }
360
361 #[cfg(test)]
362 self.check_invariants();
363
364 Ok((worktree, guest_connection_ids))
365 }
366
367 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
368 if let Some(project) = self.projects.get_mut(&project_id) {
369 if project.host_connection_id == connection_id {
370 let mut share = ProjectShare::default();
371 for worktree_id in project.worktrees.keys() {
372 share.worktrees.insert(*worktree_id, Default::default());
373 }
374 project.share = Some(share);
375 return true;
376 }
377 }
378 false
379 }
380
381 pub fn unshare_project(
382 &mut self,
383 project_id: u64,
384 acting_connection_id: ConnectionId,
385 ) -> tide::Result<UnsharedProject> {
386 let project = if let Some(project) = self.projects.get_mut(&project_id) {
387 project
388 } else {
389 return Err(anyhow!("no such project"))?;
390 };
391
392 if project.host_connection_id != acting_connection_id {
393 return Err(anyhow!("not your project"))?;
394 }
395
396 let connection_ids = project.connection_ids();
397 let authorized_user_ids = project.authorized_user_ids();
398 if let Some(share) = project.share.take() {
399 for connection_id in share.guests.into_keys() {
400 if let Some(connection) = self.connections.get_mut(&connection_id) {
401 connection.projects.remove(&project_id);
402 }
403 }
404
405 #[cfg(test)]
406 self.check_invariants();
407
408 Ok(UnsharedProject {
409 connection_ids,
410 authorized_user_ids,
411 })
412 } else {
413 Err(anyhow!("project is not shared"))?
414 }
415 }
416
417 pub fn update_diagnostic_summary(
418 &mut self,
419 project_id: u64,
420 worktree_id: u64,
421 connection_id: ConnectionId,
422 summary: proto::DiagnosticSummary,
423 ) -> tide::Result<Vec<ConnectionId>> {
424 let project = self
425 .projects
426 .get_mut(&project_id)
427 .ok_or_else(|| anyhow!("no such project"))?;
428 if project.host_connection_id == connection_id {
429 let worktree = project
430 .share_mut()?
431 .worktrees
432 .get_mut(&worktree_id)
433 .ok_or_else(|| anyhow!("no such worktree"))?;
434 worktree
435 .diagnostic_summaries
436 .insert(summary.path.clone().into(), summary);
437 return Ok(project.connection_ids());
438 }
439
440 Err(anyhow!("no such worktree"))?
441 }
442
443 pub fn start_language_server(
444 &mut self,
445 project_id: u64,
446 connection_id: ConnectionId,
447 language_server: proto::LanguageServer,
448 ) -> tide::Result<Vec<ConnectionId>> {
449 let project = self
450 .projects
451 .get_mut(&project_id)
452 .ok_or_else(|| anyhow!("no such project"))?;
453 if project.host_connection_id == connection_id {
454 project.language_servers.push(language_server);
455 return Ok(project.connection_ids());
456 }
457
458 Err(anyhow!("no such project"))?
459 }
460
461 pub fn join_project(
462 &mut self,
463 connection_id: ConnectionId,
464 user_id: UserId,
465 project_id: u64,
466 ) -> tide::Result<JoinedProject> {
467 let connection = self
468 .connections
469 .get_mut(&connection_id)
470 .ok_or_else(|| anyhow!("no such connection"))?;
471 let project = self
472 .projects
473 .get_mut(&project_id)
474 .and_then(|project| {
475 if project.has_authorized_user_id(user_id) {
476 Some(project)
477 } else {
478 None
479 }
480 })
481 .ok_or_else(|| anyhow!("no such project"))?;
482
483 let share = project.share_mut()?;
484 connection.projects.insert(project_id);
485
486 let mut replica_id = 1;
487 while share.active_replica_ids.contains(&replica_id) {
488 replica_id += 1;
489 }
490 share.active_replica_ids.insert(replica_id);
491 share.guests.insert(connection_id, (replica_id, user_id));
492
493 #[cfg(test)]
494 self.check_invariants();
495
496 Ok(JoinedProject {
497 replica_id,
498 project: &self.projects[&project_id],
499 })
500 }
501
502 pub fn leave_project(
503 &mut self,
504 connection_id: ConnectionId,
505 project_id: u64,
506 ) -> tide::Result<LeftProject> {
507 let project = self
508 .projects
509 .get_mut(&project_id)
510 .ok_or_else(|| anyhow!("no such project"))?;
511 let share = project
512 .share
513 .as_mut()
514 .ok_or_else(|| anyhow!("project is not shared"))?;
515 let (replica_id, _) = share
516 .guests
517 .remove(&connection_id)
518 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
519 share.active_replica_ids.remove(&replica_id);
520
521 if let Some(connection) = self.connections.get_mut(&connection_id) {
522 connection.projects.remove(&project_id);
523 }
524
525 let connection_ids = project.connection_ids();
526 let authorized_user_ids = project.authorized_user_ids();
527
528 #[cfg(test)]
529 self.check_invariants();
530
531 Ok(LeftProject {
532 connection_ids,
533 authorized_user_ids,
534 })
535 }
536
537 pub fn update_worktree(
538 &mut self,
539 connection_id: ConnectionId,
540 project_id: u64,
541 worktree_id: u64,
542 removed_entries: &[u64],
543 updated_entries: &[proto::Entry],
544 ) -> tide::Result<Vec<ConnectionId>> {
545 let project = self.write_project(project_id, connection_id)?;
546 let worktree = project
547 .share_mut()?
548 .worktrees
549 .get_mut(&worktree_id)
550 .ok_or_else(|| anyhow!("no such worktree"))?;
551 for entry_id in removed_entries {
552 worktree.entries.remove(&entry_id);
553 }
554 for entry in updated_entries {
555 worktree.entries.insert(entry.id, entry.clone());
556 }
557 let connection_ids = project.connection_ids();
558
559 #[cfg(test)]
560 self.check_invariants();
561
562 Ok(connection_ids)
563 }
564
565 pub fn project_connection_ids(
566 &self,
567 project_id: u64,
568 acting_connection_id: ConnectionId,
569 ) -> tide::Result<Vec<ConnectionId>> {
570 Ok(self
571 .read_project(project_id, acting_connection_id)?
572 .connection_ids())
573 }
574
575 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
576 Ok(self
577 .channels
578 .get(&channel_id)
579 .ok_or_else(|| anyhow!("no such channel"))?
580 .connection_ids())
581 }
582
583 #[cfg(test)]
584 pub fn project(&self, project_id: u64) -> Option<&Project> {
585 self.projects.get(&project_id)
586 }
587
588 pub fn read_project(
589 &self,
590 project_id: u64,
591 connection_id: ConnectionId,
592 ) -> tide::Result<&Project> {
593 let project = self
594 .projects
595 .get(&project_id)
596 .ok_or_else(|| anyhow!("no such project"))?;
597 if project.host_connection_id == connection_id
598 || project
599 .share
600 .as_ref()
601 .ok_or_else(|| anyhow!("project is not shared"))?
602 .guests
603 .contains_key(&connection_id)
604 {
605 Ok(project)
606 } else {
607 Err(anyhow!("no such project"))?
608 }
609 }
610
611 fn write_project(
612 &mut self,
613 project_id: u64,
614 connection_id: ConnectionId,
615 ) -> tide::Result<&mut Project> {
616 let project = self
617 .projects
618 .get_mut(&project_id)
619 .ok_or_else(|| anyhow!("no such project"))?;
620 if project.host_connection_id == connection_id
621 || project
622 .share
623 .as_ref()
624 .ok_or_else(|| anyhow!("project is not shared"))?
625 .guests
626 .contains_key(&connection_id)
627 {
628 Ok(project)
629 } else {
630 Err(anyhow!("no such project"))?
631 }
632 }
633
634 #[cfg(test)]
635 fn check_invariants(&self) {
636 for (connection_id, connection) in &self.connections {
637 for project_id in &connection.projects {
638 let project = &self.projects.get(&project_id).unwrap();
639 if project.host_connection_id != *connection_id {
640 assert!(project
641 .share
642 .as_ref()
643 .unwrap()
644 .guests
645 .contains_key(connection_id));
646 }
647
648 if let Some(share) = project.share.as_ref() {
649 for (worktree_id, worktree) in share.worktrees.iter() {
650 let mut paths = HashMap::default();
651 for entry in worktree.entries.values() {
652 let prev_entry = paths.insert(&entry.path, entry);
653 assert_eq!(
654 prev_entry,
655 None,
656 "worktree {:?}, duplicate path for entries {:?} and {:?}",
657 worktree_id,
658 prev_entry.unwrap(),
659 entry
660 );
661 }
662 }
663 }
664 }
665 for channel_id in &connection.channels {
666 let channel = self.channels.get(channel_id).unwrap();
667 assert!(channel.connection_ids.contains(connection_id));
668 }
669 assert!(self
670 .connections_by_user_id
671 .get(&connection.user_id)
672 .unwrap()
673 .contains(connection_id));
674 }
675
676 for (user_id, connection_ids) in &self.connections_by_user_id {
677 for connection_id in connection_ids {
678 assert_eq!(
679 self.connections.get(connection_id).unwrap().user_id,
680 *user_id
681 );
682 }
683 }
684
685 for (project_id, project) in &self.projects {
686 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
687 assert!(host_connection.projects.contains(project_id));
688
689 for authorized_user_ids in project.authorized_user_ids() {
690 let visible_project_ids = self
691 .visible_projects_by_user_id
692 .get(&authorized_user_ids)
693 .unwrap();
694 assert!(visible_project_ids.contains(project_id));
695 }
696
697 if let Some(share) = &project.share {
698 for guest_connection_id in share.guests.keys() {
699 let guest_connection = self.connections.get(guest_connection_id).unwrap();
700 assert!(guest_connection.projects.contains(project_id));
701 }
702 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
703 assert_eq!(
704 share.active_replica_ids,
705 share
706 .guests
707 .values()
708 .map(|(replica_id, _)| *replica_id)
709 .collect::<HashSet<_>>(),
710 );
711 }
712 }
713
714 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
715 for project_id in visible_project_ids {
716 let project = self.projects.get(project_id).unwrap();
717 assert!(project.authorized_user_ids().contains(user_id));
718 }
719 }
720
721 for (channel_id, channel) in &self.channels {
722 for connection_id in &channel.connection_ids {
723 let connection = self.connections.get(connection_id).unwrap();
724 assert!(connection.channels.contains(channel_id));
725 }
726 }
727 }
728}
729
730impl Project {
731 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
732 self.worktrees
733 .values()
734 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
735 }
736
737 pub fn authorized_user_ids(&self) -> Vec<UserId> {
738 let mut ids = self
739 .worktrees
740 .values()
741 .flat_map(|worktree| worktree.authorized_user_ids.iter())
742 .copied()
743 .collect::<Vec<_>>();
744 ids.sort_unstable();
745 ids.dedup();
746 ids
747 }
748
749 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
750 if let Some(share) = &self.share {
751 share.guests.keys().copied().collect()
752 } else {
753 Vec::new()
754 }
755 }
756
757 pub fn connection_ids(&self) -> Vec<ConnectionId> {
758 if let Some(share) = &self.share {
759 share
760 .guests
761 .keys()
762 .copied()
763 .chain(Some(self.host_connection_id))
764 .collect()
765 } else {
766 vec![self.host_connection_id]
767 }
768 }
769
770 pub fn share(&self) -> tide::Result<&ProjectShare> {
771 Ok(self
772 .share
773 .as_ref()
774 .ok_or_else(|| anyhow!("worktree is not shared"))?)
775 }
776
777 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
778 Ok(self
779 .share
780 .as_mut()
781 .ok_or_else(|| anyhow!("worktree is not shared"))?)
782 }
783}
784
785impl Channel {
786 fn connection_ids(&self) -> Vec<ConnectionId> {
787 self.connection_ids.iter().copied().collect()
788 }
789}