1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 pub worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33 pub weak: bool,
34}
35
36#[derive(Default)]
37pub struct ProjectShare {
38 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
39 pub active_replica_ids: HashSet<ReplicaId>,
40 pub worktrees: HashMap<u64, WorktreeShare>,
41}
42
43#[derive(Default)]
44pub struct WorktreeShare {
45 pub entries: HashMap<u64, proto::Entry>,
46 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
47}
48
49#[derive(Default)]
50pub struct Channel {
51 pub connection_ids: HashSet<ConnectionId>,
52}
53
/// Identifier assigned to each participant replica of a shared project.
/// `Store::join_project` hands guests the lowest free id starting at 1.
pub type ReplicaId = u16;
55
56#[derive(Default)]
57pub struct RemovedConnectionState {
58 pub hosted_projects: HashMap<u64, Project>,
59 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
60 pub contact_ids: HashSet<UserId>,
61}
62
63pub struct JoinedProject<'a> {
64 pub replica_id: ReplicaId,
65 pub project: &'a Project,
66}
67
68pub struct UnsharedProject {
69 pub connection_ids: Vec<ConnectionId>,
70 pub authorized_user_ids: Vec<UserId>,
71}
72
73pub struct LeftProject {
74 pub connection_ids: Vec<ConnectionId>,
75 pub authorized_user_ids: Vec<UserId>,
76}
77
78impl Store {
79 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
80 self.connections.insert(
81 connection_id,
82 ConnectionState {
83 user_id,
84 projects: Default::default(),
85 channels: Default::default(),
86 },
87 );
88 self.connections_by_user_id
89 .entry(user_id)
90 .or_default()
91 .insert(connection_id);
92 }
93
94 pub fn remove_connection(
95 &mut self,
96 connection_id: ConnectionId,
97 ) -> tide::Result<RemovedConnectionState> {
98 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
99 connection
100 } else {
101 return Err(anyhow!("no such connection"))?;
102 };
103
104 for channel_id in &connection.channels {
105 if let Some(channel) = self.channels.get_mut(&channel_id) {
106 channel.connection_ids.remove(&connection_id);
107 }
108 }
109
110 let user_connections = self
111 .connections_by_user_id
112 .get_mut(&connection.user_id)
113 .unwrap();
114 user_connections.remove(&connection_id);
115 if user_connections.is_empty() {
116 self.connections_by_user_id.remove(&connection.user_id);
117 }
118
119 let mut result = RemovedConnectionState::default();
120 for project_id in connection.projects.clone() {
121 if let Ok(project) = self.unregister_project(project_id, connection_id) {
122 result.contact_ids.extend(project.authorized_user_ids());
123 result.hosted_projects.insert(project_id, project);
124 } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125 result
126 .guest_project_ids
127 .insert(project_id, project.connection_ids);
128 result.contact_ids.extend(project.authorized_user_ids);
129 }
130 }
131
132 #[cfg(test)]
133 self.check_invariants();
134
135 Ok(result)
136 }
137
138 #[cfg(test)]
139 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140 self.channels.get(&id)
141 }
142
143 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.insert(channel_id);
146 self.channels
147 .entry(channel_id)
148 .or_default()
149 .connection_ids
150 .insert(connection_id);
151 }
152 }
153
154 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155 if let Some(connection) = self.connections.get_mut(&connection_id) {
156 connection.channels.remove(&channel_id);
157 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158 entry.get_mut().connection_ids.remove(&connection_id);
159 if entry.get_mut().connection_ids.is_empty() {
160 entry.remove();
161 }
162 }
163 }
164 }
165
166 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167 Ok(self
168 .connections
169 .get(&connection_id)
170 .ok_or_else(|| anyhow!("unknown connection"))?
171 .user_id)
172 }
173
174 pub fn connection_ids_for_user<'a>(
175 &'a self,
176 user_id: UserId,
177 ) -> impl 'a + Iterator<Item = ConnectionId> {
178 self.connections_by_user_id
179 .get(&user_id)
180 .into_iter()
181 .flatten()
182 .copied()
183 }
184
185 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186 let mut contacts = HashMap::default();
187 for project_id in self
188 .visible_projects_by_user_id
189 .get(&user_id)
190 .unwrap_or(&HashSet::default())
191 {
192 let project = &self.projects[project_id];
193
194 let mut guests = HashSet::default();
195 if let Ok(share) = project.share() {
196 for guest_connection_id in share.guests.keys() {
197 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198 guests.insert(user_id.to_proto());
199 }
200 }
201 }
202
203 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204 let mut worktree_root_names = project
205 .worktrees
206 .values()
207 .filter(|worktree| !worktree.weak)
208 .map(|worktree| worktree.root_name.clone())
209 .collect::<Vec<_>>();
210 worktree_root_names.sort_unstable();
211 contacts
212 .entry(host_user_id)
213 .or_insert_with(|| proto::Contact {
214 user_id: host_user_id.to_proto(),
215 projects: Vec::new(),
216 })
217 .projects
218 .push(proto::ProjectMetadata {
219 id: *project_id,
220 worktree_root_names,
221 is_shared: project.share.is_some(),
222 guests: guests.into_iter().collect(),
223 });
224 }
225 }
226
227 contacts.into_values().collect()
228 }
229
230 pub fn register_project(
231 &mut self,
232 host_connection_id: ConnectionId,
233 host_user_id: UserId,
234 ) -> u64 {
235 let project_id = self.next_project_id;
236 self.projects.insert(
237 project_id,
238 Project {
239 host_connection_id,
240 host_user_id,
241 share: None,
242 worktrees: Default::default(),
243 },
244 );
245 self.next_project_id += 1;
246 project_id
247 }
248
249 pub fn register_worktree(
250 &mut self,
251 project_id: u64,
252 worktree_id: u64,
253 connection_id: ConnectionId,
254 worktree: Worktree,
255 ) -> tide::Result<()> {
256 let project = self
257 .projects
258 .get_mut(&project_id)
259 .ok_or_else(|| anyhow!("no such project"))?;
260 if project.host_connection_id == connection_id {
261 for authorized_user_id in &worktree.authorized_user_ids {
262 self.visible_projects_by_user_id
263 .entry(*authorized_user_id)
264 .or_default()
265 .insert(project_id);
266 }
267 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268 connection.projects.insert(project_id);
269 }
270 project.worktrees.insert(worktree_id, worktree);
271 if let Ok(share) = project.share_mut() {
272 share.worktrees.insert(worktree_id, Default::default());
273 }
274
275 #[cfg(test)]
276 self.check_invariants();
277 Ok(())
278 } else {
279 Err(anyhow!("no such project"))?
280 }
281 }
282
283 pub fn unregister_project(
284 &mut self,
285 project_id: u64,
286 connection_id: ConnectionId,
287 ) -> tide::Result<Project> {
288 match self.projects.entry(project_id) {
289 hash_map::Entry::Occupied(e) => {
290 if e.get().host_connection_id == connection_id {
291 for user_id in e.get().authorized_user_ids() {
292 if let hash_map::Entry::Occupied(mut projects) =
293 self.visible_projects_by_user_id.entry(user_id)
294 {
295 projects.get_mut().remove(&project_id);
296 }
297 }
298
299 let project = e.remove();
300 if let Some(share) = &project.share {
301 for guest_connection in share.guests.keys() {
302 if let Some(connection) = self.connections.get_mut(&guest_connection) {
303 connection.projects.remove(&project_id);
304 }
305 }
306 }
307
308 Ok(project)
309 } else {
310 Err(anyhow!("no such project"))?
311 }
312 }
313 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
314 }
315 }
316
317 pub fn unregister_worktree(
318 &mut self,
319 project_id: u64,
320 worktree_id: u64,
321 acting_connection_id: ConnectionId,
322 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
323 let project = self
324 .projects
325 .get_mut(&project_id)
326 .ok_or_else(|| anyhow!("no such project"))?;
327 if project.host_connection_id != acting_connection_id {
328 Err(anyhow!("not your worktree"))?;
329 }
330
331 let worktree = project
332 .worktrees
333 .remove(&worktree_id)
334 .ok_or_else(|| anyhow!("no such worktree"))?;
335
336 let mut guest_connection_ids = Vec::new();
337 if let Ok(share) = project.share_mut() {
338 guest_connection_ids.extend(share.guests.keys());
339 share.worktrees.remove(&worktree_id);
340 }
341
342 for authorized_user_id in &worktree.authorized_user_ids {
343 if let Some(visible_projects) =
344 self.visible_projects_by_user_id.get_mut(authorized_user_id)
345 {
346 if !project.has_authorized_user_id(*authorized_user_id) {
347 visible_projects.remove(&project_id);
348 }
349 }
350 }
351
352 #[cfg(test)]
353 self.check_invariants();
354
355 Ok((worktree, guest_connection_ids))
356 }
357
358 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
359 if let Some(project) = self.projects.get_mut(&project_id) {
360 if project.host_connection_id == connection_id {
361 let mut share = ProjectShare::default();
362 for worktree_id in project.worktrees.keys() {
363 share.worktrees.insert(*worktree_id, Default::default());
364 }
365 project.share = Some(share);
366 return true;
367 }
368 }
369 false
370 }
371
372 pub fn unshare_project(
373 &mut self,
374 project_id: u64,
375 acting_connection_id: ConnectionId,
376 ) -> tide::Result<UnsharedProject> {
377 let project = if let Some(project) = self.projects.get_mut(&project_id) {
378 project
379 } else {
380 return Err(anyhow!("no such project"))?;
381 };
382
383 if project.host_connection_id != acting_connection_id {
384 return Err(anyhow!("not your project"))?;
385 }
386
387 let connection_ids = project.connection_ids();
388 let authorized_user_ids = project.authorized_user_ids();
389 if let Some(share) = project.share.take() {
390 for connection_id in share.guests.into_keys() {
391 if let Some(connection) = self.connections.get_mut(&connection_id) {
392 connection.projects.remove(&project_id);
393 }
394 }
395
396 #[cfg(test)]
397 self.check_invariants();
398
399 Ok(UnsharedProject {
400 connection_ids,
401 authorized_user_ids,
402 })
403 } else {
404 Err(anyhow!("project is not shared"))?
405 }
406 }
407
408 pub fn update_diagnostic_summary(
409 &mut self,
410 project_id: u64,
411 worktree_id: u64,
412 connection_id: ConnectionId,
413 summary: proto::DiagnosticSummary,
414 ) -> tide::Result<Vec<ConnectionId>> {
415 let project = self
416 .projects
417 .get_mut(&project_id)
418 .ok_or_else(|| anyhow!("no such project"))?;
419 if project.host_connection_id == connection_id {
420 let worktree = project
421 .share_mut()?
422 .worktrees
423 .get_mut(&worktree_id)
424 .ok_or_else(|| anyhow!("no such worktree"))?;
425 worktree
426 .diagnostic_summaries
427 .insert(summary.path.clone().into(), summary);
428 return Ok(project.connection_ids());
429 }
430
431 Err(anyhow!("no such worktree"))?
432 }
433
434 pub fn join_project(
435 &mut self,
436 connection_id: ConnectionId,
437 user_id: UserId,
438 project_id: u64,
439 ) -> tide::Result<JoinedProject> {
440 let connection = self
441 .connections
442 .get_mut(&connection_id)
443 .ok_or_else(|| anyhow!("no such connection"))?;
444 let project = self
445 .projects
446 .get_mut(&project_id)
447 .and_then(|project| {
448 if project.has_authorized_user_id(user_id) {
449 Some(project)
450 } else {
451 None
452 }
453 })
454 .ok_or_else(|| anyhow!("no such project"))?;
455
456 let share = project.share_mut()?;
457 connection.projects.insert(project_id);
458
459 let mut replica_id = 1;
460 while share.active_replica_ids.contains(&replica_id) {
461 replica_id += 1;
462 }
463 share.active_replica_ids.insert(replica_id);
464 share.guests.insert(connection_id, (replica_id, user_id));
465
466 #[cfg(test)]
467 self.check_invariants();
468
469 Ok(JoinedProject {
470 replica_id,
471 project: &self.projects[&project_id],
472 })
473 }
474
475 pub fn leave_project(
476 &mut self,
477 connection_id: ConnectionId,
478 project_id: u64,
479 ) -> tide::Result<LeftProject> {
480 let project = self
481 .projects
482 .get_mut(&project_id)
483 .ok_or_else(|| anyhow!("no such project"))?;
484 let share = project
485 .share
486 .as_mut()
487 .ok_or_else(|| anyhow!("project is not shared"))?;
488 let (replica_id, _) = share
489 .guests
490 .remove(&connection_id)
491 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
492 share.active_replica_ids.remove(&replica_id);
493
494 if let Some(connection) = self.connections.get_mut(&connection_id) {
495 connection.projects.remove(&project_id);
496 }
497
498 let connection_ids = project.connection_ids();
499 let authorized_user_ids = project.authorized_user_ids();
500
501 #[cfg(test)]
502 self.check_invariants();
503
504 Ok(LeftProject {
505 connection_ids,
506 authorized_user_ids,
507 })
508 }
509
510 pub fn update_worktree(
511 &mut self,
512 connection_id: ConnectionId,
513 project_id: u64,
514 worktree_id: u64,
515 removed_entries: &[u64],
516 updated_entries: &[proto::Entry],
517 ) -> tide::Result<Vec<ConnectionId>> {
518 let project = self.write_project(project_id, connection_id)?;
519 let worktree = project
520 .share_mut()?
521 .worktrees
522 .get_mut(&worktree_id)
523 .ok_or_else(|| anyhow!("no such worktree"))?;
524 for entry_id in removed_entries {
525 worktree.entries.remove(&entry_id);
526 }
527 for entry in updated_entries {
528 worktree.entries.insert(entry.id, entry.clone());
529 }
530 Ok(project.connection_ids())
531 }
532
533 pub fn project_connection_ids(
534 &self,
535 project_id: u64,
536 acting_connection_id: ConnectionId,
537 ) -> tide::Result<Vec<ConnectionId>> {
538 Ok(self
539 .read_project(project_id, acting_connection_id)?
540 .connection_ids())
541 }
542
543 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
544 Ok(self
545 .channels
546 .get(&channel_id)
547 .ok_or_else(|| anyhow!("no such channel"))?
548 .connection_ids())
549 }
550
551 #[cfg(test)]
552 pub fn project(&self, project_id: u64) -> Option<&Project> {
553 self.projects.get(&project_id)
554 }
555
556 pub fn read_project(
557 &self,
558 project_id: u64,
559 connection_id: ConnectionId,
560 ) -> tide::Result<&Project> {
561 let project = self
562 .projects
563 .get(&project_id)
564 .ok_or_else(|| anyhow!("no such project"))?;
565 if project.host_connection_id == connection_id
566 || project
567 .share
568 .as_ref()
569 .ok_or_else(|| anyhow!("project is not shared"))?
570 .guests
571 .contains_key(&connection_id)
572 {
573 Ok(project)
574 } else {
575 Err(anyhow!("no such project"))?
576 }
577 }
578
579 fn write_project(
580 &mut self,
581 project_id: u64,
582 connection_id: ConnectionId,
583 ) -> tide::Result<&mut Project> {
584 let project = self
585 .projects
586 .get_mut(&project_id)
587 .ok_or_else(|| anyhow!("no such project"))?;
588 if project.host_connection_id == connection_id
589 || project
590 .share
591 .as_ref()
592 .ok_or_else(|| anyhow!("project is not shared"))?
593 .guests
594 .contains_key(&connection_id)
595 {
596 Ok(project)
597 } else {
598 Err(anyhow!("no such project"))?
599 }
600 }
601
602 #[cfg(test)]
603 fn check_invariants(&self) {
604 for (connection_id, connection) in &self.connections {
605 for project_id in &connection.projects {
606 let project = &self.projects.get(&project_id).unwrap();
607 if project.host_connection_id != *connection_id {
608 assert!(project
609 .share
610 .as_ref()
611 .unwrap()
612 .guests
613 .contains_key(connection_id));
614 }
615 }
616 for channel_id in &connection.channels {
617 let channel = self.channels.get(channel_id).unwrap();
618 assert!(channel.connection_ids.contains(connection_id));
619 }
620 assert!(self
621 .connections_by_user_id
622 .get(&connection.user_id)
623 .unwrap()
624 .contains(connection_id));
625 }
626
627 for (user_id, connection_ids) in &self.connections_by_user_id {
628 for connection_id in connection_ids {
629 assert_eq!(
630 self.connections.get(connection_id).unwrap().user_id,
631 *user_id
632 );
633 }
634 }
635
636 for (project_id, project) in &self.projects {
637 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
638 assert!(host_connection.projects.contains(project_id));
639
640 for authorized_user_ids in project.authorized_user_ids() {
641 let visible_project_ids = self
642 .visible_projects_by_user_id
643 .get(&authorized_user_ids)
644 .unwrap();
645 assert!(visible_project_ids.contains(project_id));
646 }
647
648 if let Some(share) = &project.share {
649 for guest_connection_id in share.guests.keys() {
650 let guest_connection = self.connections.get(guest_connection_id).unwrap();
651 assert!(guest_connection.projects.contains(project_id));
652 }
653 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
654 assert_eq!(
655 share.active_replica_ids,
656 share
657 .guests
658 .values()
659 .map(|(replica_id, _)| *replica_id)
660 .collect::<HashSet<_>>(),
661 );
662 }
663 }
664
665 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
666 for project_id in visible_project_ids {
667 let project = self.projects.get(project_id).unwrap();
668 assert!(project.authorized_user_ids().contains(user_id));
669 }
670 }
671
672 for (channel_id, channel) in &self.channels {
673 for connection_id in &channel.connection_ids {
674 let connection = self.connections.get(connection_id).unwrap();
675 assert!(connection.channels.contains(channel_id));
676 }
677 }
678 }
679}
680
681impl Project {
682 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
683 self.worktrees
684 .values()
685 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
686 }
687
688 pub fn authorized_user_ids(&self) -> Vec<UserId> {
689 let mut ids = self
690 .worktrees
691 .values()
692 .flat_map(|worktree| worktree.authorized_user_ids.iter())
693 .copied()
694 .collect::<Vec<_>>();
695 ids.sort_unstable();
696 ids.dedup();
697 ids
698 }
699
700 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
701 if let Some(share) = &self.share {
702 share.guests.keys().copied().collect()
703 } else {
704 Vec::new()
705 }
706 }
707
708 pub fn connection_ids(&self) -> Vec<ConnectionId> {
709 if let Some(share) = &self.share {
710 share
711 .guests
712 .keys()
713 .copied()
714 .chain(Some(self.host_connection_id))
715 .collect()
716 } else {
717 vec![self.host_connection_id]
718 }
719 }
720
721 pub fn share(&self) -> tide::Result<&ProjectShare> {
722 Ok(self
723 .share
724 .as_ref()
725 .ok_or_else(|| anyhow!("worktree is not shared"))?)
726 }
727
728 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
729 Ok(self
730 .share
731 .as_mut()
732 .ok_or_else(|| anyhow!("worktree is not shared"))?)
733 }
734}
735
736impl Channel {
737 fn connection_ids(&self) -> Vec<ConnectionId> {
738 self.connection_ids.iter().copied().collect()
739 }
740}