1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 pub worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33 pub weak: bool,
34}
35
36#[derive(Default)]
37pub struct ProjectShare {
38 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
39 pub active_replica_ids: HashSet<ReplicaId>,
40 pub worktrees: HashMap<u64, WorktreeShare>,
41}
42
43#[derive(Default)]
44pub struct WorktreeShare {
45 pub entries: HashMap<u64, proto::Entry>,
46 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
47}
48
49#[derive(Default)]
50pub struct Channel {
51 pub connection_ids: HashSet<ConnectionId>,
52}
53
/// Identifier of a guest replica inside a shared project.
pub type ReplicaId = u16;
55
56#[derive(Default)]
57pub struct RemovedConnectionState {
58 pub hosted_projects: HashMap<u64, Project>,
59 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
60 pub contact_ids: HashSet<UserId>,
61}
62
63pub struct JoinedProject<'a> {
64 pub replica_id: ReplicaId,
65 pub project: &'a Project,
66}
67
68pub struct UnsharedProject {
69 pub connection_ids: Vec<ConnectionId>,
70 pub authorized_user_ids: Vec<UserId>,
71}
72
73pub struct LeftProject {
74 pub connection_ids: Vec<ConnectionId>,
75 pub authorized_user_ids: Vec<UserId>,
76}
77
78impl Store {
79 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
80 self.connections.insert(
81 connection_id,
82 ConnectionState {
83 user_id,
84 projects: Default::default(),
85 channels: Default::default(),
86 },
87 );
88 self.connections_by_user_id
89 .entry(user_id)
90 .or_default()
91 .insert(connection_id);
92 }
93
94 pub fn remove_connection(
95 &mut self,
96 connection_id: ConnectionId,
97 ) -> tide::Result<RemovedConnectionState> {
98 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
99 connection
100 } else {
101 return Err(anyhow!("no such connection"))?;
102 };
103
104 for channel_id in &connection.channels {
105 if let Some(channel) = self.channels.get_mut(&channel_id) {
106 channel.connection_ids.remove(&connection_id);
107 }
108 }
109
110 let user_connections = self
111 .connections_by_user_id
112 .get_mut(&connection.user_id)
113 .unwrap();
114 user_connections.remove(&connection_id);
115 if user_connections.is_empty() {
116 self.connections_by_user_id.remove(&connection.user_id);
117 }
118
119 let mut result = RemovedConnectionState::default();
120 for project_id in connection.projects.clone() {
121 if let Ok(project) = self.unregister_project(project_id, connection_id) {
122 result.contact_ids.extend(project.authorized_user_ids());
123 result.hosted_projects.insert(project_id, project);
124 } else if let Ok(project) = self.leave_project(connection_id, project_id) {
125 result
126 .guest_project_ids
127 .insert(project_id, project.connection_ids);
128 result.contact_ids.extend(project.authorized_user_ids);
129 }
130 }
131
132 #[cfg(test)]
133 self.check_invariants();
134
135 Ok(result)
136 }
137
138 #[cfg(test)]
139 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
140 self.channels.get(&id)
141 }
142
143 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.insert(channel_id);
146 self.channels
147 .entry(channel_id)
148 .or_default()
149 .connection_ids
150 .insert(connection_id);
151 }
152 }
153
154 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
155 if let Some(connection) = self.connections.get_mut(&connection_id) {
156 connection.channels.remove(&channel_id);
157 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
158 entry.get_mut().connection_ids.remove(&connection_id);
159 if entry.get_mut().connection_ids.is_empty() {
160 entry.remove();
161 }
162 }
163 }
164 }
165
166 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
167 Ok(self
168 .connections
169 .get(&connection_id)
170 .ok_or_else(|| anyhow!("unknown connection"))?
171 .user_id)
172 }
173
174 pub fn connection_ids_for_user<'a>(
175 &'a self,
176 user_id: UserId,
177 ) -> impl 'a + Iterator<Item = ConnectionId> {
178 self.connections_by_user_id
179 .get(&user_id)
180 .into_iter()
181 .flatten()
182 .copied()
183 }
184
185 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
186 let mut contacts = HashMap::default();
187 for project_id in self
188 .visible_projects_by_user_id
189 .get(&user_id)
190 .unwrap_or(&HashSet::default())
191 {
192 let project = &self.projects[project_id];
193
194 let mut guests = HashSet::default();
195 if let Ok(share) = project.share() {
196 for guest_connection_id in share.guests.keys() {
197 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
198 guests.insert(user_id.to_proto());
199 }
200 }
201 }
202
203 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
204 let mut worktree_root_names = project
205 .worktrees
206 .values()
207 .filter(|worktree| !worktree.weak)
208 .map(|worktree| worktree.root_name.clone())
209 .collect::<Vec<_>>();
210 worktree_root_names.sort_unstable();
211 contacts
212 .entry(host_user_id)
213 .or_insert_with(|| proto::Contact {
214 user_id: host_user_id.to_proto(),
215 projects: Vec::new(),
216 })
217 .projects
218 .push(proto::ProjectMetadata {
219 id: *project_id,
220 worktree_root_names,
221 is_shared: project.share.is_some(),
222 guests: guests.into_iter().collect(),
223 });
224 }
225 }
226
227 contacts.into_values().collect()
228 }
229
230 pub fn register_project(
231 &mut self,
232 host_connection_id: ConnectionId,
233 host_user_id: UserId,
234 ) -> u64 {
235 let project_id = self.next_project_id;
236 self.projects.insert(
237 project_id,
238 Project {
239 host_connection_id,
240 host_user_id,
241 share: None,
242 worktrees: Default::default(),
243 },
244 );
245 self.next_project_id += 1;
246 project_id
247 }
248
249 pub fn register_worktree(
250 &mut self,
251 project_id: u64,
252 worktree_id: u64,
253 connection_id: ConnectionId,
254 worktree: Worktree,
255 ) -> tide::Result<()> {
256 let project = self
257 .projects
258 .get_mut(&project_id)
259 .ok_or_else(|| anyhow!("no such project"))?;
260 if project.host_connection_id == connection_id {
261 for authorized_user_id in &worktree.authorized_user_ids {
262 self.visible_projects_by_user_id
263 .entry(*authorized_user_id)
264 .or_default()
265 .insert(project_id);
266 }
267 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
268 connection.projects.insert(project_id);
269 }
270 project.worktrees.insert(worktree_id, worktree);
271 if let Ok(share) = project.share_mut() {
272 share.worktrees.insert(worktree_id, Default::default());
273 }
274
275 #[cfg(test)]
276 self.check_invariants();
277 Ok(())
278 } else {
279 Err(anyhow!("no such project"))?
280 }
281 }
282
283 pub fn unregister_project(
284 &mut self,
285 project_id: u64,
286 connection_id: ConnectionId,
287 ) -> tide::Result<Project> {
288 match self.projects.entry(project_id) {
289 hash_map::Entry::Occupied(e) => {
290 if e.get().host_connection_id == connection_id {
291 for user_id in e.get().authorized_user_ids() {
292 if let hash_map::Entry::Occupied(mut projects) =
293 self.visible_projects_by_user_id.entry(user_id)
294 {
295 projects.get_mut().remove(&project_id);
296 }
297 }
298
299 let project = e.remove();
300
301 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
302 host_connection.projects.remove(&project_id);
303 }
304
305 if let Some(share) = &project.share {
306 for guest_connection in share.guests.keys() {
307 if let Some(connection) = self.connections.get_mut(&guest_connection) {
308 connection.projects.remove(&project_id);
309 }
310 }
311 }
312
313 #[cfg(test)]
314 self.check_invariants();
315 Ok(project)
316 } else {
317 Err(anyhow!("no such project"))?
318 }
319 }
320 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
321 }
322 }
323
324 pub fn unregister_worktree(
325 &mut self,
326 project_id: u64,
327 worktree_id: u64,
328 acting_connection_id: ConnectionId,
329 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
330 let project = self
331 .projects
332 .get_mut(&project_id)
333 .ok_or_else(|| anyhow!("no such project"))?;
334 if project.host_connection_id != acting_connection_id {
335 Err(anyhow!("not your worktree"))?;
336 }
337
338 let worktree = project
339 .worktrees
340 .remove(&worktree_id)
341 .ok_or_else(|| anyhow!("no such worktree"))?;
342
343 let mut guest_connection_ids = Vec::new();
344 if let Ok(share) = project.share_mut() {
345 guest_connection_ids.extend(share.guests.keys());
346 share.worktrees.remove(&worktree_id);
347 }
348
349 for authorized_user_id in &worktree.authorized_user_ids {
350 if let Some(visible_projects) =
351 self.visible_projects_by_user_id.get_mut(authorized_user_id)
352 {
353 if !project.has_authorized_user_id(*authorized_user_id) {
354 visible_projects.remove(&project_id);
355 }
356 }
357 }
358
359 #[cfg(test)]
360 self.check_invariants();
361
362 Ok((worktree, guest_connection_ids))
363 }
364
365 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
366 if let Some(project) = self.projects.get_mut(&project_id) {
367 if project.host_connection_id == connection_id {
368 let mut share = ProjectShare::default();
369 for worktree_id in project.worktrees.keys() {
370 share.worktrees.insert(*worktree_id, Default::default());
371 }
372 project.share = Some(share);
373 return true;
374 }
375 }
376 false
377 }
378
379 pub fn unshare_project(
380 &mut self,
381 project_id: u64,
382 acting_connection_id: ConnectionId,
383 ) -> tide::Result<UnsharedProject> {
384 let project = if let Some(project) = self.projects.get_mut(&project_id) {
385 project
386 } else {
387 return Err(anyhow!("no such project"))?;
388 };
389
390 if project.host_connection_id != acting_connection_id {
391 return Err(anyhow!("not your project"))?;
392 }
393
394 let connection_ids = project.connection_ids();
395 let authorized_user_ids = project.authorized_user_ids();
396 if let Some(share) = project.share.take() {
397 for connection_id in share.guests.into_keys() {
398 if let Some(connection) = self.connections.get_mut(&connection_id) {
399 connection.projects.remove(&project_id);
400 }
401 }
402
403 #[cfg(test)]
404 self.check_invariants();
405
406 Ok(UnsharedProject {
407 connection_ids,
408 authorized_user_ids,
409 })
410 } else {
411 Err(anyhow!("project is not shared"))?
412 }
413 }
414
415 pub fn update_diagnostic_summary(
416 &mut self,
417 project_id: u64,
418 worktree_id: u64,
419 connection_id: ConnectionId,
420 summary: proto::DiagnosticSummary,
421 ) -> tide::Result<Vec<ConnectionId>> {
422 let project = self
423 .projects
424 .get_mut(&project_id)
425 .ok_or_else(|| anyhow!("no such project"))?;
426 if project.host_connection_id == connection_id {
427 let worktree = project
428 .share_mut()?
429 .worktrees
430 .get_mut(&worktree_id)
431 .ok_or_else(|| anyhow!("no such worktree"))?;
432 worktree
433 .diagnostic_summaries
434 .insert(summary.path.clone().into(), summary);
435 return Ok(project.connection_ids());
436 }
437
438 Err(anyhow!("no such worktree"))?
439 }
440
441 pub fn join_project(
442 &mut self,
443 connection_id: ConnectionId,
444 user_id: UserId,
445 project_id: u64,
446 ) -> tide::Result<JoinedProject> {
447 let connection = self
448 .connections
449 .get_mut(&connection_id)
450 .ok_or_else(|| anyhow!("no such connection"))?;
451 let project = self
452 .projects
453 .get_mut(&project_id)
454 .and_then(|project| {
455 if project.has_authorized_user_id(user_id) {
456 Some(project)
457 } else {
458 None
459 }
460 })
461 .ok_or_else(|| anyhow!("no such project"))?;
462
463 let share = project.share_mut()?;
464 connection.projects.insert(project_id);
465
466 let mut replica_id = 1;
467 while share.active_replica_ids.contains(&replica_id) {
468 replica_id += 1;
469 }
470 share.active_replica_ids.insert(replica_id);
471 share.guests.insert(connection_id, (replica_id, user_id));
472
473 #[cfg(test)]
474 self.check_invariants();
475
476 Ok(JoinedProject {
477 replica_id,
478 project: &self.projects[&project_id],
479 })
480 }
481
482 pub fn leave_project(
483 &mut self,
484 connection_id: ConnectionId,
485 project_id: u64,
486 ) -> tide::Result<LeftProject> {
487 let project = self
488 .projects
489 .get_mut(&project_id)
490 .ok_or_else(|| anyhow!("no such project"))?;
491 let share = project
492 .share
493 .as_mut()
494 .ok_or_else(|| anyhow!("project is not shared"))?;
495 let (replica_id, _) = share
496 .guests
497 .remove(&connection_id)
498 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
499 share.active_replica_ids.remove(&replica_id);
500
501 if let Some(connection) = self.connections.get_mut(&connection_id) {
502 connection.projects.remove(&project_id);
503 }
504
505 let connection_ids = project.connection_ids();
506 let authorized_user_ids = project.authorized_user_ids();
507
508 #[cfg(test)]
509 self.check_invariants();
510
511 Ok(LeftProject {
512 connection_ids,
513 authorized_user_ids,
514 })
515 }
516
517 pub fn update_worktree(
518 &mut self,
519 connection_id: ConnectionId,
520 project_id: u64,
521 worktree_id: u64,
522 removed_entries: &[u64],
523 updated_entries: &[proto::Entry],
524 ) -> tide::Result<Vec<ConnectionId>> {
525 let project = self.write_project(project_id, connection_id)?;
526 let worktree = project
527 .share_mut()?
528 .worktrees
529 .get_mut(&worktree_id)
530 .ok_or_else(|| anyhow!("no such worktree"))?;
531 for entry_id in removed_entries {
532 worktree.entries.remove(&entry_id);
533 }
534 for entry in updated_entries {
535 worktree.entries.insert(entry.id, entry.clone());
536 }
537 Ok(project.connection_ids())
538 }
539
540 pub fn project_connection_ids(
541 &self,
542 project_id: u64,
543 acting_connection_id: ConnectionId,
544 ) -> tide::Result<Vec<ConnectionId>> {
545 Ok(self
546 .read_project(project_id, acting_connection_id)?
547 .connection_ids())
548 }
549
550 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
551 Ok(self
552 .channels
553 .get(&channel_id)
554 .ok_or_else(|| anyhow!("no such channel"))?
555 .connection_ids())
556 }
557
558 #[cfg(test)]
559 pub fn project(&self, project_id: u64) -> Option<&Project> {
560 self.projects.get(&project_id)
561 }
562
563 pub fn read_project(
564 &self,
565 project_id: u64,
566 connection_id: ConnectionId,
567 ) -> tide::Result<&Project> {
568 let project = self
569 .projects
570 .get(&project_id)
571 .ok_or_else(|| anyhow!("no such project"))?;
572 if project.host_connection_id == connection_id
573 || project
574 .share
575 .as_ref()
576 .ok_or_else(|| anyhow!("project is not shared"))?
577 .guests
578 .contains_key(&connection_id)
579 {
580 Ok(project)
581 } else {
582 Err(anyhow!("no such project"))?
583 }
584 }
585
586 fn write_project(
587 &mut self,
588 project_id: u64,
589 connection_id: ConnectionId,
590 ) -> tide::Result<&mut Project> {
591 let project = self
592 .projects
593 .get_mut(&project_id)
594 .ok_or_else(|| anyhow!("no such project"))?;
595 if project.host_connection_id == connection_id
596 || project
597 .share
598 .as_ref()
599 .ok_or_else(|| anyhow!("project is not shared"))?
600 .guests
601 .contains_key(&connection_id)
602 {
603 Ok(project)
604 } else {
605 Err(anyhow!("no such project"))?
606 }
607 }
608
609 #[cfg(test)]
610 fn check_invariants(&self) {
611 for (connection_id, connection) in &self.connections {
612 for project_id in &connection.projects {
613 let project = &self.projects.get(&project_id).unwrap();
614 if project.host_connection_id != *connection_id {
615 assert!(project
616 .share
617 .as_ref()
618 .unwrap()
619 .guests
620 .contains_key(connection_id));
621 }
622 }
623 for channel_id in &connection.channels {
624 let channel = self.channels.get(channel_id).unwrap();
625 assert!(channel.connection_ids.contains(connection_id));
626 }
627 assert!(self
628 .connections_by_user_id
629 .get(&connection.user_id)
630 .unwrap()
631 .contains(connection_id));
632 }
633
634 for (user_id, connection_ids) in &self.connections_by_user_id {
635 for connection_id in connection_ids {
636 assert_eq!(
637 self.connections.get(connection_id).unwrap().user_id,
638 *user_id
639 );
640 }
641 }
642
643 for (project_id, project) in &self.projects {
644 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
645 assert!(host_connection.projects.contains(project_id));
646
647 for authorized_user_ids in project.authorized_user_ids() {
648 let visible_project_ids = self
649 .visible_projects_by_user_id
650 .get(&authorized_user_ids)
651 .unwrap();
652 assert!(visible_project_ids.contains(project_id));
653 }
654
655 if let Some(share) = &project.share {
656 for guest_connection_id in share.guests.keys() {
657 let guest_connection = self.connections.get(guest_connection_id).unwrap();
658 assert!(guest_connection.projects.contains(project_id));
659 }
660 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
661 assert_eq!(
662 share.active_replica_ids,
663 share
664 .guests
665 .values()
666 .map(|(replica_id, _)| *replica_id)
667 .collect::<HashSet<_>>(),
668 );
669 }
670 }
671
672 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
673 for project_id in visible_project_ids {
674 let project = self.projects.get(project_id).unwrap();
675 assert!(project.authorized_user_ids().contains(user_id));
676 }
677 }
678
679 for (channel_id, channel) in &self.channels {
680 for connection_id in &channel.connection_ids {
681 let connection = self.connections.get(connection_id).unwrap();
682 assert!(connection.channels.contains(channel_id));
683 }
684 }
685 }
686}
687
688impl Project {
689 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
690 self.worktrees
691 .values()
692 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
693 }
694
695 pub fn authorized_user_ids(&self) -> Vec<UserId> {
696 let mut ids = self
697 .worktrees
698 .values()
699 .flat_map(|worktree| worktree.authorized_user_ids.iter())
700 .copied()
701 .collect::<Vec<_>>();
702 ids.sort_unstable();
703 ids.dedup();
704 ids
705 }
706
707 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
708 if let Some(share) = &self.share {
709 share.guests.keys().copied().collect()
710 } else {
711 Vec::new()
712 }
713 }
714
715 pub fn connection_ids(&self) -> Vec<ConnectionId> {
716 if let Some(share) = &self.share {
717 share
718 .guests
719 .keys()
720 .copied()
721 .chain(Some(self.host_connection_id))
722 .collect()
723 } else {
724 vec![self.host_connection_id]
725 }
726 }
727
728 pub fn share(&self) -> tide::Result<&ProjectShare> {
729 Ok(self
730 .share
731 .as_ref()
732 .ok_or_else(|| anyhow!("worktree is not shared"))?)
733 }
734
735 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
736 Ok(self
737 .share
738 .as_mut()
739 .ok_or_else(|| anyhow!("worktree is not shared"))?)
740 }
741}
742
743impl Channel {
744 fn connection_ids(&self) -> Vec<ConnectionId> {
745 self.connection_ids.iter().copied().collect()
746 }
747}