1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 pub worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33 pub share: Option<WorktreeShare>,
34 pub weak: bool,
35}
36
37#[derive(Default)]
38pub struct ProjectShare {
39 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
40 pub active_replica_ids: HashSet<ReplicaId>,
41}
42
43pub struct WorktreeShare {
44 pub entries: HashMap<u64, proto::Entry>,
45 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
46}
47
48#[derive(Default)]
49pub struct Channel {
50 pub connection_ids: HashSet<ConnectionId>,
51}
52
/// Identifier assigned to each guest of a shared project; `join_project`
/// hands out the lowest unused value starting at 1.
pub type ReplicaId = u16;
54
55#[derive(Default)]
56pub struct RemovedConnectionState {
57 pub hosted_projects: HashMap<u64, Project>,
58 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
59 pub contact_ids: HashSet<UserId>,
60}
61
62pub struct JoinedProject<'a> {
63 pub replica_id: ReplicaId,
64 pub project: &'a Project,
65}
66
67pub struct UnsharedProject {
68 pub connection_ids: Vec<ConnectionId>,
69 pub authorized_user_ids: Vec<UserId>,
70}
71
72pub struct LeftProject {
73 pub connection_ids: Vec<ConnectionId>,
74 pub authorized_user_ids: Vec<UserId>,
75}
76
77pub struct SharedWorktree {
78 pub authorized_user_ids: Vec<UserId>,
79 pub connection_ids: Vec<ConnectionId>,
80}
81
82impl Store {
83 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
84 self.connections.insert(
85 connection_id,
86 ConnectionState {
87 user_id,
88 projects: Default::default(),
89 channels: Default::default(),
90 },
91 );
92 self.connections_by_user_id
93 .entry(user_id)
94 .or_default()
95 .insert(connection_id);
96 }
97
98 pub fn remove_connection(
99 &mut self,
100 connection_id: ConnectionId,
101 ) -> tide::Result<RemovedConnectionState> {
102 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
103 connection
104 } else {
105 return Err(anyhow!("no such connection"))?;
106 };
107
108 for channel_id in &connection.channels {
109 if let Some(channel) = self.channels.get_mut(&channel_id) {
110 channel.connection_ids.remove(&connection_id);
111 }
112 }
113
114 let user_connections = self
115 .connections_by_user_id
116 .get_mut(&connection.user_id)
117 .unwrap();
118 user_connections.remove(&connection_id);
119 if user_connections.is_empty() {
120 self.connections_by_user_id.remove(&connection.user_id);
121 }
122
123 let mut result = RemovedConnectionState::default();
124 for project_id in connection.projects.clone() {
125 if let Ok(project) = self.unregister_project(project_id, connection_id) {
126 result.contact_ids.extend(project.authorized_user_ids());
127 result.hosted_projects.insert(project_id, project);
128 } else if let Ok(project) = self.leave_project(connection_id, project_id) {
129 result
130 .guest_project_ids
131 .insert(project_id, project.connection_ids);
132 result.contact_ids.extend(project.authorized_user_ids);
133 }
134 }
135
136 #[cfg(test)]
137 self.check_invariants();
138
139 Ok(result)
140 }
141
/// Test-only accessor: looks up a channel by id.
#[cfg(test)]
pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
    let channels = &self.channels;
    channels.get(&id)
}
146
147 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
148 if let Some(connection) = self.connections.get_mut(&connection_id) {
149 connection.channels.insert(channel_id);
150 self.channels
151 .entry(channel_id)
152 .or_default()
153 .connection_ids
154 .insert(connection_id);
155 }
156 }
157
158 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
159 if let Some(connection) = self.connections.get_mut(&connection_id) {
160 connection.channels.remove(&channel_id);
161 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
162 entry.get_mut().connection_ids.remove(&connection_id);
163 if entry.get_mut().connection_ids.is_empty() {
164 entry.remove();
165 }
166 }
167 }
168 }
169
170 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
171 Ok(self
172 .connections
173 .get(&connection_id)
174 .ok_or_else(|| anyhow!("unknown connection"))?
175 .user_id)
176 }
177
178 pub fn connection_ids_for_user<'a>(
179 &'a self,
180 user_id: UserId,
181 ) -> impl 'a + Iterator<Item = ConnectionId> {
182 self.connections_by_user_id
183 .get(&user_id)
184 .into_iter()
185 .flatten()
186 .copied()
187 }
188
189 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
190 let mut contacts = HashMap::default();
191 for project_id in self
192 .visible_projects_by_user_id
193 .get(&user_id)
194 .unwrap_or(&HashSet::default())
195 {
196 let project = &self.projects[project_id];
197
198 let mut guests = HashSet::default();
199 if let Ok(share) = project.share() {
200 for guest_connection_id in share.guests.keys() {
201 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
202 guests.insert(user_id.to_proto());
203 }
204 }
205 }
206
207 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
208 let mut worktree_root_names = project
209 .worktrees
210 .values()
211 .filter(|worktree| !worktree.weak)
212 .map(|worktree| worktree.root_name.clone())
213 .collect::<Vec<_>>();
214 worktree_root_names.sort_unstable();
215 contacts
216 .entry(host_user_id)
217 .or_insert_with(|| proto::Contact {
218 user_id: host_user_id.to_proto(),
219 projects: Vec::new(),
220 })
221 .projects
222 .push(proto::ProjectMetadata {
223 id: *project_id,
224 worktree_root_names,
225 is_shared: project.share.is_some(),
226 guests: guests.into_iter().collect(),
227 });
228 }
229 }
230
231 contacts.into_values().collect()
232 }
233
234 pub fn register_project(
235 &mut self,
236 host_connection_id: ConnectionId,
237 host_user_id: UserId,
238 ) -> u64 {
239 let project_id = self.next_project_id;
240 self.projects.insert(
241 project_id,
242 Project {
243 host_connection_id,
244 host_user_id,
245 share: None,
246 worktrees: Default::default(),
247 },
248 );
249 self.next_project_id += 1;
250 project_id
251 }
252
253 pub fn register_worktree(
254 &mut self,
255 project_id: u64,
256 worktree_id: u64,
257 connection_id: ConnectionId,
258 worktree: Worktree,
259 ) -> tide::Result<()> {
260 let project = self
261 .projects
262 .get_mut(&project_id)
263 .ok_or_else(|| anyhow!("no such project"))?;
264 if project.host_connection_id == connection_id {
265 for authorized_user_id in &worktree.authorized_user_ids {
266 self.visible_projects_by_user_id
267 .entry(*authorized_user_id)
268 .or_default()
269 .insert(project_id);
270 }
271 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
272 connection.projects.insert(project_id);
273 }
274 project.worktrees.insert(worktree_id, worktree);
275
276 #[cfg(test)]
277 self.check_invariants();
278 Ok(())
279 } else {
280 Err(anyhow!("no such project"))?
281 }
282 }
283
284 pub fn unregister_project(
285 &mut self,
286 project_id: u64,
287 connection_id: ConnectionId,
288 ) -> tide::Result<Project> {
289 match self.projects.entry(project_id) {
290 hash_map::Entry::Occupied(e) => {
291 if e.get().host_connection_id == connection_id {
292 for user_id in e.get().authorized_user_ids() {
293 if let hash_map::Entry::Occupied(mut projects) =
294 self.visible_projects_by_user_id.entry(user_id)
295 {
296 projects.get_mut().remove(&project_id);
297 }
298 }
299
300 Ok(e.remove())
301 } else {
302 Err(anyhow!("no such project"))?
303 }
304 }
305 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
306 }
307 }
308
309 pub fn unregister_worktree(
310 &mut self,
311 project_id: u64,
312 worktree_id: u64,
313 acting_connection_id: ConnectionId,
314 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
315 let project = self
316 .projects
317 .get_mut(&project_id)
318 .ok_or_else(|| anyhow!("no such project"))?;
319 if project.host_connection_id != acting_connection_id {
320 Err(anyhow!("not your worktree"))?;
321 }
322
323 let worktree = project
324 .worktrees
325 .remove(&worktree_id)
326 .ok_or_else(|| anyhow!("no such worktree"))?;
327
328 let mut guest_connection_ids = Vec::new();
329 if let Some(share) = &project.share {
330 guest_connection_ids.extend(share.guests.keys());
331 }
332
333 for authorized_user_id in &worktree.authorized_user_ids {
334 if let Some(visible_projects) =
335 self.visible_projects_by_user_id.get_mut(authorized_user_id)
336 {
337 if !project.has_authorized_user_id(*authorized_user_id) {
338 visible_projects.remove(&project_id);
339 }
340 }
341 }
342
343 #[cfg(test)]
344 self.check_invariants();
345
346 Ok((worktree, guest_connection_ids))
347 }
348
349 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
350 if let Some(project) = self.projects.get_mut(&project_id) {
351 if project.host_connection_id == connection_id {
352 project.share = Some(ProjectShare::default());
353 return true;
354 }
355 }
356 false
357 }
358
359 pub fn unshare_project(
360 &mut self,
361 project_id: u64,
362 acting_connection_id: ConnectionId,
363 ) -> tide::Result<UnsharedProject> {
364 let project = if let Some(project) = self.projects.get_mut(&project_id) {
365 project
366 } else {
367 return Err(anyhow!("no such project"))?;
368 };
369
370 if project.host_connection_id != acting_connection_id {
371 return Err(anyhow!("not your project"))?;
372 }
373
374 let connection_ids = project.connection_ids();
375 let authorized_user_ids = project.authorized_user_ids();
376 if let Some(share) = project.share.take() {
377 for connection_id in share.guests.into_keys() {
378 if let Some(connection) = self.connections.get_mut(&connection_id) {
379 connection.projects.remove(&project_id);
380 }
381 }
382
383 for worktree in project.worktrees.values_mut() {
384 worktree.share.take();
385 }
386
387 #[cfg(test)]
388 self.check_invariants();
389
390 Ok(UnsharedProject {
391 connection_ids,
392 authorized_user_ids,
393 })
394 } else {
395 Err(anyhow!("project is not shared"))?
396 }
397 }
398
399 pub fn share_worktree(
400 &mut self,
401 project_id: u64,
402 worktree_id: u64,
403 connection_id: ConnectionId,
404 entries: HashMap<u64, proto::Entry>,
405 diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
406 ) -> tide::Result<SharedWorktree> {
407 let project = self
408 .projects
409 .get_mut(&project_id)
410 .ok_or_else(|| anyhow!("no such project"))?;
411 let worktree = project
412 .worktrees
413 .get_mut(&worktree_id)
414 .ok_or_else(|| anyhow!("no such worktree"))?;
415 if project.host_connection_id == connection_id && project.share.is_some() {
416 worktree.share = Some(WorktreeShare {
417 entries,
418 diagnostic_summaries,
419 });
420 Ok(SharedWorktree {
421 authorized_user_ids: project.authorized_user_ids(),
422 connection_ids: project.guest_connection_ids(),
423 })
424 } else {
425 Err(anyhow!("no such worktree"))?
426 }
427 }
428
429 pub fn update_diagnostic_summary(
430 &mut self,
431 project_id: u64,
432 worktree_id: u64,
433 connection_id: ConnectionId,
434 summary: proto::DiagnosticSummary,
435 ) -> tide::Result<Vec<ConnectionId>> {
436 let project = self
437 .projects
438 .get_mut(&project_id)
439 .ok_or_else(|| anyhow!("no such project"))?;
440 let worktree = project
441 .worktrees
442 .get_mut(&worktree_id)
443 .ok_or_else(|| anyhow!("no such worktree"))?;
444 if project.host_connection_id == connection_id {
445 if let Some(share) = worktree.share.as_mut() {
446 share
447 .diagnostic_summaries
448 .insert(summary.path.clone().into(), summary);
449 return Ok(project.connection_ids());
450 }
451 }
452
453 Err(anyhow!("no such worktree"))?
454 }
455
456 pub fn join_project(
457 &mut self,
458 connection_id: ConnectionId,
459 user_id: UserId,
460 project_id: u64,
461 ) -> tide::Result<JoinedProject> {
462 let connection = self
463 .connections
464 .get_mut(&connection_id)
465 .ok_or_else(|| anyhow!("no such connection"))?;
466 let project = self
467 .projects
468 .get_mut(&project_id)
469 .and_then(|project| {
470 if project.has_authorized_user_id(user_id) {
471 Some(project)
472 } else {
473 None
474 }
475 })
476 .ok_or_else(|| anyhow!("no such project"))?;
477
478 let share = project.share_mut()?;
479 connection.projects.insert(project_id);
480
481 let mut replica_id = 1;
482 while share.active_replica_ids.contains(&replica_id) {
483 replica_id += 1;
484 }
485 share.active_replica_ids.insert(replica_id);
486 share.guests.insert(connection_id, (replica_id, user_id));
487
488 #[cfg(test)]
489 self.check_invariants();
490
491 Ok(JoinedProject {
492 replica_id,
493 project: &self.projects[&project_id],
494 })
495 }
496
497 pub fn leave_project(
498 &mut self,
499 connection_id: ConnectionId,
500 project_id: u64,
501 ) -> tide::Result<LeftProject> {
502 let project = self
503 .projects
504 .get_mut(&project_id)
505 .ok_or_else(|| anyhow!("no such project"))?;
506 let share = project
507 .share
508 .as_mut()
509 .ok_or_else(|| anyhow!("project is not shared"))?;
510 let (replica_id, _) = share
511 .guests
512 .remove(&connection_id)
513 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514 share.active_replica_ids.remove(&replica_id);
515
516 if let Some(connection) = self.connections.get_mut(&connection_id) {
517 connection.projects.remove(&project_id);
518 }
519
520 let connection_ids = project.connection_ids();
521 let authorized_user_ids = project.authorized_user_ids();
522
523 #[cfg(test)]
524 self.check_invariants();
525
526 Ok(LeftProject {
527 connection_ids,
528 authorized_user_ids,
529 })
530 }
531
532 pub fn update_worktree(
533 &mut self,
534 connection_id: ConnectionId,
535 project_id: u64,
536 worktree_id: u64,
537 removed_entries: &[u64],
538 updated_entries: &[proto::Entry],
539 ) -> tide::Result<Vec<ConnectionId>> {
540 let project = self.write_project(project_id, connection_id)?;
541 let share = project
542 .worktrees
543 .get_mut(&worktree_id)
544 .ok_or_else(|| anyhow!("no such worktree"))?
545 .share
546 .as_mut()
547 .ok_or_else(|| anyhow!("worktree is not shared"))?;
548 for entry_id in removed_entries {
549 share.entries.remove(&entry_id);
550 }
551 for entry in updated_entries {
552 share.entries.insert(entry.id, entry.clone());
553 }
554 Ok(project.connection_ids())
555 }
556
557 pub fn project_connection_ids(
558 &self,
559 project_id: u64,
560 acting_connection_id: ConnectionId,
561 ) -> tide::Result<Vec<ConnectionId>> {
562 Ok(self
563 .read_project(project_id, acting_connection_id)?
564 .connection_ids())
565 }
566
567 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
568 Ok(self
569 .channels
570 .get(&channel_id)
571 .ok_or_else(|| anyhow!("no such channel"))?
572 .connection_ids())
573 }
574
/// Test-only accessor: looks up a project by id without any access check.
#[cfg(test)]
pub fn project(&self, project_id: u64) -> Option<&Project> {
    let projects = &self.projects;
    projects.get(&project_id)
}
579
580 pub fn read_project(
581 &self,
582 project_id: u64,
583 connection_id: ConnectionId,
584 ) -> tide::Result<&Project> {
585 let project = self
586 .projects
587 .get(&project_id)
588 .ok_or_else(|| anyhow!("no such project"))?;
589 if project.host_connection_id == connection_id
590 || project
591 .share
592 .as_ref()
593 .ok_or_else(|| anyhow!("project is not shared"))?
594 .guests
595 .contains_key(&connection_id)
596 {
597 Ok(project)
598 } else {
599 Err(anyhow!("no such project"))?
600 }
601 }
602
603 fn write_project(
604 &mut self,
605 project_id: u64,
606 connection_id: ConnectionId,
607 ) -> tide::Result<&mut Project> {
608 let project = self
609 .projects
610 .get_mut(&project_id)
611 .ok_or_else(|| anyhow!("no such project"))?;
612 if project.host_connection_id == connection_id
613 || project
614 .share
615 .as_ref()
616 .ok_or_else(|| anyhow!("project is not shared"))?
617 .guests
618 .contains_key(&connection_id)
619 {
620 Ok(project)
621 } else {
622 Err(anyhow!("no such project"))?
623 }
624 }
625
/// Test-only consistency check; panics if the store's indexes disagree.
///
/// Verifies, in order: each connection's projects, channels and user index;
/// the user index against the connections; each project's host, visibility
/// and guest/replica bookkeeping; the visibility index against project
/// authorization; and channel membership against connection state.
#[cfg(test)]
fn check_invariants(&self) {
    for (connection_id, connection) in &self.connections {
        for project_id in &connection.projects {
            let project = self.projects.get(project_id).unwrap();
            let is_host = project.host_connection_id == *connection_id;
            if !is_host {
                // A non-host connection must be a registered guest.
                let share = project.share.as_ref().unwrap();
                assert!(share.guests.contains_key(connection_id));
            }
        }
        for channel_id in &connection.channels {
            let channel = self.channels.get(channel_id).unwrap();
            assert!(channel.connection_ids.contains(connection_id));
        }
        let user_connection_ids = self
            .connections_by_user_id
            .get(&connection.user_id)
            .unwrap();
        assert!(user_connection_ids.contains(connection_id));
    }

    for (user_id, connection_ids) in &self.connections_by_user_id {
        for connection_id in connection_ids {
            let connection = self.connections.get(connection_id).unwrap();
            assert_eq!(connection.user_id, *user_id);
        }
    }

    for (project_id, project) in &self.projects {
        let host_connection = self.connections.get(&project.host_connection_id).unwrap();
        assert!(host_connection.projects.contains(project_id));

        for authorized_user_id in project.authorized_user_ids() {
            let visible_project_ids = self
                .visible_projects_by_user_id
                .get(&authorized_user_id)
                .unwrap();
            assert!(visible_project_ids.contains(project_id));
        }

        if let Some(share) = &project.share {
            for guest_connection_id in share.guests.keys() {
                let guest_connection = self.connections.get(guest_connection_id).unwrap();
                assert!(guest_connection.projects.contains(project_id));
            }
            assert_eq!(share.active_replica_ids.len(), share.guests.len());
            let guest_replica_ids = share
                .guests
                .values()
                .map(|(replica_id, _)| *replica_id)
                .collect::<HashSet<_>>();
            assert_eq!(share.active_replica_ids, guest_replica_ids);
        }
    }

    for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
        for project_id in visible_project_ids {
            let project = self.projects.get(project_id).unwrap();
            assert!(project.authorized_user_ids().contains(user_id));
        }
    }

    for (channel_id, channel) in &self.channels {
        for connection_id in &channel.connection_ids {
            let connection = self.connections.get(connection_id).unwrap();
            assert!(connection.channels.contains(channel_id));
        }
    }
}
703}
704
705impl Project {
706 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
707 self.worktrees
708 .values()
709 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
710 }
711
712 pub fn authorized_user_ids(&self) -> Vec<UserId> {
713 let mut ids = self
714 .worktrees
715 .values()
716 .flat_map(|worktree| worktree.authorized_user_ids.iter())
717 .copied()
718 .collect::<Vec<_>>();
719 ids.sort_unstable();
720 ids.dedup();
721 ids
722 }
723
724 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
725 if let Some(share) = &self.share {
726 share.guests.keys().copied().collect()
727 } else {
728 Vec::new()
729 }
730 }
731
732 pub fn connection_ids(&self) -> Vec<ConnectionId> {
733 if let Some(share) = &self.share {
734 share
735 .guests
736 .keys()
737 .copied()
738 .chain(Some(self.host_connection_id))
739 .collect()
740 } else {
741 vec![self.host_connection_id]
742 }
743 }
744
745 pub fn share(&self) -> tide::Result<&ProjectShare> {
746 Ok(self
747 .share
748 .as_ref()
749 .ok_or_else(|| anyhow!("worktree is not shared"))?)
750 }
751
752 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
753 Ok(self
754 .share
755 .as_mut()
756 .ok_or_else(|| anyhow!("worktree is not shared"))?)
757 }
758}
759
760impl Channel {
761 fn connection_ids(&self) -> Vec<ConnectionId> {
762 self.connection_ids.iter().copied().collect()
763 }
764}