1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::{collections::hash_map, path::PathBuf};
6
/// In-memory bookkeeping for live connections, the projects they host or
/// have joined, and the channels they are in.
#[derive(Default)]
pub struct Store {
    /// Per-connection state, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Index from a user to every connection currently open for that user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// Registered projects, keyed by project id.
    projects: HashMap<u64, Project>,
    /// For each user, the ids of projects in which at least one worktree
    /// lists that user as authorized.
    visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channels that currently have members, keyed by channel id.
    channels: HashMap<ChannelId, Channel>,
    /// Id handed to the next registered project; only ever incremented.
    next_project_id: u64,
}
16
/// State tracked for a single connection.
struct ConnectionState {
    /// The user that owns this connection.
    user_id: UserId,
    /// Ids of projects this connection hosts or has joined as a guest.
    projects: HashSet<u64>,
    /// Channels this connection has joined.
    channels: HashSet<ChannelId>,
}
22
/// A project registered by a host connection.
pub struct Project {
    /// The connection that registered (and hosts) the project.
    pub host_connection_id: ConnectionId,
    /// The user that owns the host connection.
    pub host_user_id: UserId,
    /// `Some` while the project is shared; holds the guest bookkeeping.
    pub share: Option<ProjectShare>,
    /// The project's worktrees, keyed by worktree id.
    pub worktrees: HashMap<u64, Worktree>,
}
29
/// A worktree belonging to a project.
pub struct Worktree {
    /// Users allowed to see (and join) the project through this worktree.
    pub authorized_user_ids: Vec<UserId>,
    /// Display name of the worktree's root.
    pub root_name: String,
    /// `Some` while the worktree's contents are shared with guests.
    pub share: Option<WorktreeShare>,
    /// Weak worktrees are omitted from the root names reported to contacts.
    pub weak: bool,
}
36
/// Guest bookkeeping for a shared project.
#[derive(Default)]
pub struct ProjectShare {
    /// Each guest connection, with its assigned replica id and its user id.
    pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
    /// Replica ids currently assigned to guests; kept in sync with `guests`.
    pub active_replica_ids: HashSet<ReplicaId>,
}
42
/// The shared contents of a worktree, as last reported by the host.
pub struct WorktreeShare {
    /// Worktree entries, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
    /// Latest diagnostic summary per path.
    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
}
47
/// Membership of a chat channel.
#[derive(Default)]
pub struct Channel {
    /// Connections currently in the channel.
    pub connection_ids: HashSet<ConnectionId>,
}
52
/// Replica id assigned to a project guest; `Store::join_project` hands out
/// the lowest unused id starting from 1.
pub type ReplicaId = u16;
54
/// Everything affected by removing a connection, so callers can notify peers.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Projects the connection hosted; they have been unregistered.
    pub hosted_projects: HashMap<u64, Project>,
    /// For each project the connection had joined as a guest, the connection
    /// ids that remain in that project after the guest left.
    pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Authorized users of every affected project.
    pub contact_ids: HashSet<UserId>,
}
61
/// Result of a guest joining a project.
pub struct JoinedProject<'a> {
    /// Replica id assigned to the joining guest.
    pub replica_id: ReplicaId,
    /// The joined project, borrowed from the store.
    pub project: &'a Project,
}
66
/// Result of a host unsharing a project.
pub struct UnsharedProject {
    /// Host and guest connections, captured before the guests were removed.
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized by any of the project's worktrees.
    pub authorized_user_ids: Vec<UserId>,
}
71
/// Result of a guest leaving a project.
pub struct LeftProject {
    /// Connections still in the project after the guest left (host included).
    pub connection_ids: Vec<ConnectionId>,
    /// Users authorized by any of the project's worktrees.
    pub authorized_user_ids: Vec<UserId>,
}
76
/// Result of a host sharing a worktree.
pub struct SharedWorktree {
    /// Users authorized by any of the project's worktrees.
    pub authorized_user_ids: Vec<UserId>,
    /// Guest connections of the project (the host is not included).
    pub connection_ids: Vec<ConnectionId>,
}
81
82impl Store {
83 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
84 self.connections.insert(
85 connection_id,
86 ConnectionState {
87 user_id,
88 projects: Default::default(),
89 channels: Default::default(),
90 },
91 );
92 self.connections_by_user_id
93 .entry(user_id)
94 .or_default()
95 .insert(connection_id);
96 }
97
98 pub fn remove_connection(
99 &mut self,
100 connection_id: ConnectionId,
101 ) -> tide::Result<RemovedConnectionState> {
102 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
103 connection
104 } else {
105 return Err(anyhow!("no such connection"))?;
106 };
107
108 for channel_id in &connection.channels {
109 if let Some(channel) = self.channels.get_mut(&channel_id) {
110 channel.connection_ids.remove(&connection_id);
111 }
112 }
113
114 let user_connections = self
115 .connections_by_user_id
116 .get_mut(&connection.user_id)
117 .unwrap();
118 user_connections.remove(&connection_id);
119 if user_connections.is_empty() {
120 self.connections_by_user_id.remove(&connection.user_id);
121 }
122
123 let mut result = RemovedConnectionState::default();
124 for project_id in connection.projects.clone() {
125 if let Some(project) = self.unregister_project(project_id, connection_id) {
126 result.contact_ids.extend(project.authorized_user_ids());
127 result.hosted_projects.insert(project_id, project);
128 } else if let Some(project) = self.leave_project(connection_id, project_id) {
129 result
130 .guest_project_ids
131 .insert(project_id, project.connection_ids);
132 result.contact_ids.extend(project.authorized_user_ids);
133 }
134 }
135
136 #[cfg(test)]
137 self.check_invariants();
138
139 Ok(result)
140 }
141
142 #[cfg(test)]
143 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
144 self.channels.get(&id)
145 }
146
147 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
148 if let Some(connection) = self.connections.get_mut(&connection_id) {
149 connection.channels.insert(channel_id);
150 self.channels
151 .entry(channel_id)
152 .or_default()
153 .connection_ids
154 .insert(connection_id);
155 }
156 }
157
158 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
159 if let Some(connection) = self.connections.get_mut(&connection_id) {
160 connection.channels.remove(&channel_id);
161 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
162 entry.get_mut().connection_ids.remove(&connection_id);
163 if entry.get_mut().connection_ids.is_empty() {
164 entry.remove();
165 }
166 }
167 }
168 }
169
170 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
171 Ok(self
172 .connections
173 .get(&connection_id)
174 .ok_or_else(|| anyhow!("unknown connection"))?
175 .user_id)
176 }
177
178 pub fn connection_ids_for_user<'a>(
179 &'a self,
180 user_id: UserId,
181 ) -> impl 'a + Iterator<Item = ConnectionId> {
182 self.connections_by_user_id
183 .get(&user_id)
184 .into_iter()
185 .flatten()
186 .copied()
187 }
188
189 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
190 let mut contacts = HashMap::default();
191 for project_id in self
192 .visible_projects_by_user_id
193 .get(&user_id)
194 .unwrap_or(&HashSet::default())
195 {
196 let project = &self.projects[project_id];
197
198 let mut guests = HashSet::default();
199 if let Ok(share) = project.share() {
200 for guest_connection_id in share.guests.keys() {
201 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
202 guests.insert(user_id.to_proto());
203 }
204 }
205 }
206
207 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
208 let mut worktree_root_names = project
209 .worktrees
210 .values()
211 .filter(|worktree| !worktree.weak)
212 .map(|worktree| worktree.root_name.clone())
213 .collect::<Vec<_>>();
214 worktree_root_names.sort_unstable();
215 contacts
216 .entry(host_user_id)
217 .or_insert_with(|| proto::Contact {
218 user_id: host_user_id.to_proto(),
219 projects: Vec::new(),
220 })
221 .projects
222 .push(proto::ProjectMetadata {
223 id: *project_id,
224 worktree_root_names,
225 is_shared: project.share.is_some(),
226 guests: guests.into_iter().collect(),
227 });
228 }
229 }
230
231 contacts.into_values().collect()
232 }
233
234 pub fn register_project(
235 &mut self,
236 host_connection_id: ConnectionId,
237 host_user_id: UserId,
238 ) -> u64 {
239 let project_id = self.next_project_id;
240 self.projects.insert(
241 project_id,
242 Project {
243 host_connection_id,
244 host_user_id,
245 share: None,
246 worktrees: Default::default(),
247 },
248 );
249 self.next_project_id += 1;
250 project_id
251 }
252
253 pub fn register_worktree(
254 &mut self,
255 project_id: u64,
256 worktree_id: u64,
257 worktree: Worktree,
258 ) -> bool {
259 if let Some(project) = self.projects.get_mut(&project_id) {
260 for authorized_user_id in &worktree.authorized_user_ids {
261 self.visible_projects_by_user_id
262 .entry(*authorized_user_id)
263 .or_default()
264 .insert(project_id);
265 }
266 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
267 connection.projects.insert(project_id);
268 }
269 project.worktrees.insert(worktree_id, worktree);
270
271 #[cfg(test)]
272 self.check_invariants();
273 true
274 } else {
275 false
276 }
277 }
278
279 pub fn unregister_project(
280 &mut self,
281 project_id: u64,
282 connection_id: ConnectionId,
283 ) -> Option<Project> {
284 match self.projects.entry(project_id) {
285 hash_map::Entry::Occupied(e) => {
286 if e.get().host_connection_id == connection_id {
287 for user_id in e.get().authorized_user_ids() {
288 if let hash_map::Entry::Occupied(mut projects) =
289 self.visible_projects_by_user_id.entry(user_id)
290 {
291 projects.get_mut().remove(&project_id);
292 }
293 }
294
295 Some(e.remove())
296 } else {
297 None
298 }
299 }
300 hash_map::Entry::Vacant(_) => None,
301 }
302 }
303
304 pub fn unregister_worktree(
305 &mut self,
306 project_id: u64,
307 worktree_id: u64,
308 acting_connection_id: ConnectionId,
309 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
310 let project = self
311 .projects
312 .get_mut(&project_id)
313 .ok_or_else(|| anyhow!("no such project"))?;
314 if project.host_connection_id != acting_connection_id {
315 Err(anyhow!("not your worktree"))?;
316 }
317
318 let worktree = project
319 .worktrees
320 .remove(&worktree_id)
321 .ok_or_else(|| anyhow!("no such worktree"))?;
322
323 let mut guest_connection_ids = Vec::new();
324 if let Some(share) = &project.share {
325 guest_connection_ids.extend(share.guests.keys());
326 }
327
328 for authorized_user_id in &worktree.authorized_user_ids {
329 if let Some(visible_projects) =
330 self.visible_projects_by_user_id.get_mut(authorized_user_id)
331 {
332 if !project.has_authorized_user_id(*authorized_user_id) {
333 visible_projects.remove(&project_id);
334 }
335 }
336 }
337
338 #[cfg(test)]
339 self.check_invariants();
340
341 Ok((worktree, guest_connection_ids))
342 }
343
344 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
345 if let Some(project) = self.projects.get_mut(&project_id) {
346 if project.host_connection_id == connection_id {
347 project.share = Some(ProjectShare::default());
348 return true;
349 }
350 }
351 false
352 }
353
354 pub fn unshare_project(
355 &mut self,
356 project_id: u64,
357 acting_connection_id: ConnectionId,
358 ) -> tide::Result<UnsharedProject> {
359 let project = if let Some(project) = self.projects.get_mut(&project_id) {
360 project
361 } else {
362 return Err(anyhow!("no such project"))?;
363 };
364
365 if project.host_connection_id != acting_connection_id {
366 return Err(anyhow!("not your project"))?;
367 }
368
369 let connection_ids = project.connection_ids();
370 let authorized_user_ids = project.authorized_user_ids();
371 if let Some(share) = project.share.take() {
372 for connection_id in share.guests.into_keys() {
373 if let Some(connection) = self.connections.get_mut(&connection_id) {
374 connection.projects.remove(&project_id);
375 }
376 }
377
378 for worktree in project.worktrees.values_mut() {
379 worktree.share.take();
380 }
381
382 #[cfg(test)]
383 self.check_invariants();
384
385 Ok(UnsharedProject {
386 connection_ids,
387 authorized_user_ids,
388 })
389 } else {
390 Err(anyhow!("project is not shared"))?
391 }
392 }
393
394 pub fn share_worktree(
395 &mut self,
396 project_id: u64,
397 worktree_id: u64,
398 connection_id: ConnectionId,
399 entries: HashMap<u64, proto::Entry>,
400 diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
401 ) -> Option<SharedWorktree> {
402 let project = self.projects.get_mut(&project_id)?;
403 let worktree = project.worktrees.get_mut(&worktree_id)?;
404 if project.host_connection_id == connection_id && project.share.is_some() {
405 worktree.share = Some(WorktreeShare {
406 entries,
407 diagnostic_summaries,
408 });
409 Some(SharedWorktree {
410 authorized_user_ids: project.authorized_user_ids(),
411 connection_ids: project.guest_connection_ids(),
412 })
413 } else {
414 None
415 }
416 }
417
418 pub fn update_diagnostic_summary(
419 &mut self,
420 project_id: u64,
421 worktree_id: u64,
422 connection_id: ConnectionId,
423 summary: proto::DiagnosticSummary,
424 ) -> Option<Vec<ConnectionId>> {
425 let project = self.projects.get_mut(&project_id)?;
426 let worktree = project.worktrees.get_mut(&worktree_id)?;
427 if project.host_connection_id == connection_id {
428 if let Some(share) = worktree.share.as_mut() {
429 share
430 .diagnostic_summaries
431 .insert(summary.path.clone().into(), summary);
432 return Some(project.connection_ids());
433 }
434 }
435
436 None
437 }
438
439 pub fn join_project(
440 &mut self,
441 connection_id: ConnectionId,
442 user_id: UserId,
443 project_id: u64,
444 ) -> tide::Result<JoinedProject> {
445 let connection = self
446 .connections
447 .get_mut(&connection_id)
448 .ok_or_else(|| anyhow!("no such connection"))?;
449 let project = self
450 .projects
451 .get_mut(&project_id)
452 .and_then(|project| {
453 if project.has_authorized_user_id(user_id) {
454 Some(project)
455 } else {
456 None
457 }
458 })
459 .ok_or_else(|| anyhow!("no such project"))?;
460
461 let share = project.share_mut()?;
462 connection.projects.insert(project_id);
463
464 let mut replica_id = 1;
465 while share.active_replica_ids.contains(&replica_id) {
466 replica_id += 1;
467 }
468 share.active_replica_ids.insert(replica_id);
469 share.guests.insert(connection_id, (replica_id, user_id));
470
471 #[cfg(test)]
472 self.check_invariants();
473
474 Ok(JoinedProject {
475 replica_id,
476 project: &self.projects[&project_id],
477 })
478 }
479
480 pub fn leave_project(
481 &mut self,
482 connection_id: ConnectionId,
483 project_id: u64,
484 ) -> Option<LeftProject> {
485 let project = self.projects.get_mut(&project_id)?;
486 let share = project.share.as_mut()?;
487 let (replica_id, _) = share.guests.remove(&connection_id)?;
488 share.active_replica_ids.remove(&replica_id);
489
490 if let Some(connection) = self.connections.get_mut(&connection_id) {
491 connection.projects.remove(&project_id);
492 }
493
494 let connection_ids = project.connection_ids();
495 let authorized_user_ids = project.authorized_user_ids();
496
497 #[cfg(test)]
498 self.check_invariants();
499
500 Some(LeftProject {
501 connection_ids,
502 authorized_user_ids,
503 })
504 }
505
506 pub fn update_worktree(
507 &mut self,
508 connection_id: ConnectionId,
509 project_id: u64,
510 worktree_id: u64,
511 removed_entries: &[u64],
512 updated_entries: &[proto::Entry],
513 ) -> Option<Vec<ConnectionId>> {
514 let project = self.write_project(project_id, connection_id)?;
515 let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
516 for entry_id in removed_entries {
517 share.entries.remove(&entry_id);
518 }
519 for entry in updated_entries {
520 share.entries.insert(entry.id, entry.clone());
521 }
522 Some(project.connection_ids())
523 }
524
525 pub fn project_connection_ids(
526 &self,
527 project_id: u64,
528 acting_connection_id: ConnectionId,
529 ) -> Option<Vec<ConnectionId>> {
530 Some(
531 self.read_project(project_id, acting_connection_id)?
532 .connection_ids(),
533 )
534 }
535
536 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
537 Some(self.channels.get(&channel_id)?.connection_ids())
538 }
539
540 #[cfg(test)]
541 pub fn project(&self, project_id: u64) -> Option<&Project> {
542 self.projects.get(&project_id)
543 }
544
545 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
546 let project = self.projects.get(&project_id)?;
547 if project.host_connection_id == connection_id
548 || project.share.as_ref()?.guests.contains_key(&connection_id)
549 {
550 Some(project)
551 } else {
552 None
553 }
554 }
555
556 fn write_project(
557 &mut self,
558 project_id: u64,
559 connection_id: ConnectionId,
560 ) -> Option<&mut Project> {
561 let project = self.projects.get_mut(&project_id)?;
562 if project.host_connection_id == connection_id
563 || project.share.as_ref()?.guests.contains_key(&connection_id)
564 {
565 Some(project)
566 } else {
567 None
568 }
569 }
570
571 #[cfg(test)]
572 fn check_invariants(&self) {
573 for (connection_id, connection) in &self.connections {
574 for project_id in &connection.projects {
575 let project = &self.projects.get(&project_id).unwrap();
576 if project.host_connection_id != *connection_id {
577 assert!(project
578 .share
579 .as_ref()
580 .unwrap()
581 .guests
582 .contains_key(connection_id));
583 }
584 }
585 for channel_id in &connection.channels {
586 let channel = self.channels.get(channel_id).unwrap();
587 assert!(channel.connection_ids.contains(connection_id));
588 }
589 assert!(self
590 .connections_by_user_id
591 .get(&connection.user_id)
592 .unwrap()
593 .contains(connection_id));
594 }
595
596 for (user_id, connection_ids) in &self.connections_by_user_id {
597 for connection_id in connection_ids {
598 assert_eq!(
599 self.connections.get(connection_id).unwrap().user_id,
600 *user_id
601 );
602 }
603 }
604
605 for (project_id, project) in &self.projects {
606 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
607 assert!(host_connection.projects.contains(project_id));
608
609 for authorized_user_ids in project.authorized_user_ids() {
610 let visible_project_ids = self
611 .visible_projects_by_user_id
612 .get(&authorized_user_ids)
613 .unwrap();
614 assert!(visible_project_ids.contains(project_id));
615 }
616
617 if let Some(share) = &project.share {
618 for guest_connection_id in share.guests.keys() {
619 let guest_connection = self.connections.get(guest_connection_id).unwrap();
620 assert!(guest_connection.projects.contains(project_id));
621 }
622 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
623 assert_eq!(
624 share.active_replica_ids,
625 share
626 .guests
627 .values()
628 .map(|(replica_id, _)| *replica_id)
629 .collect::<HashSet<_>>(),
630 );
631 }
632 }
633
634 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
635 for project_id in visible_project_ids {
636 let project = self.projects.get(project_id).unwrap();
637 assert!(project.authorized_user_ids().contains(user_id));
638 }
639 }
640
641 for (channel_id, channel) in &self.channels {
642 for connection_id in &channel.connection_ids {
643 let connection = self.connections.get(connection_id).unwrap();
644 assert!(connection.channels.contains(channel_id));
645 }
646 }
647 }
648}
649
650impl Project {
651 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
652 self.worktrees
653 .values()
654 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
655 }
656
657 pub fn authorized_user_ids(&self) -> Vec<UserId> {
658 let mut ids = self
659 .worktrees
660 .values()
661 .flat_map(|worktree| worktree.authorized_user_ids.iter())
662 .copied()
663 .collect::<Vec<_>>();
664 ids.sort_unstable();
665 ids.dedup();
666 ids
667 }
668
669 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
670 if let Some(share) = &self.share {
671 share.guests.keys().copied().collect()
672 } else {
673 Vec::new()
674 }
675 }
676
677 pub fn connection_ids(&self) -> Vec<ConnectionId> {
678 if let Some(share) = &self.share {
679 share
680 .guests
681 .keys()
682 .copied()
683 .chain(Some(self.host_connection_id))
684 .collect()
685 } else {
686 vec![self.host_connection_id]
687 }
688 }
689
690 pub fn share(&self) -> tide::Result<&ProjectShare> {
691 Ok(self
692 .share
693 .as_ref()
694 .ok_or_else(|| anyhow!("worktree is not shared"))?)
695 }
696
697 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
698 Ok(self
699 .share
700 .as_mut()
701 .ok_or_else(|| anyhow!("worktree is not shared"))?)
702 }
703}
704
705impl Channel {
706 fn connection_ids(&self) -> Vec<ConnectionId> {
707 self.connection_ids.iter().copied().collect()
708 }
709}