1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::collections::hash_map;
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 pub worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33 pub share: Option<WorktreeShare>,
34}
35
36#[derive(Default)]
37pub struct ProjectShare {
38 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
39 pub active_replica_ids: HashSet<ReplicaId>,
40}
41
42pub struct WorktreeShare {
43 pub entries: HashMap<u64, proto::Entry>,
44}
45
46#[derive(Default)]
47pub struct Channel {
48 pub connection_ids: HashSet<ConnectionId>,
49}
50
/// Replica identifier assigned to a project guest (`Store::join_project` hands them out from 1 upward).
pub type ReplicaId = u16;
52
53#[derive(Default)]
54pub struct RemovedConnectionState {
55 pub hosted_projects: HashMap<u64, Project>,
56 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
57 pub contact_ids: HashSet<UserId>,
58}
59
60pub struct JoinedProject<'a> {
61 pub replica_id: ReplicaId,
62 pub project: &'a Project,
63}
64
65pub struct UnsharedWorktree {
66 pub connection_ids: Vec<ConnectionId>,
67 pub authorized_user_ids: Vec<UserId>,
68}
69
70pub struct LeftProject {
71 pub connection_ids: Vec<ConnectionId>,
72 pub authorized_user_ids: Vec<UserId>,
73}
74
75impl Store {
76 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
77 self.connections.insert(
78 connection_id,
79 ConnectionState {
80 user_id,
81 projects: Default::default(),
82 channels: Default::default(),
83 },
84 );
85 self.connections_by_user_id
86 .entry(user_id)
87 .or_default()
88 .insert(connection_id);
89 }
90
91 pub fn remove_connection(
92 &mut self,
93 connection_id: ConnectionId,
94 ) -> tide::Result<RemovedConnectionState> {
95 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
96 connection
97 } else {
98 return Err(anyhow!("no such connection"))?;
99 };
100
101 for channel_id in &connection.channels {
102 if let Some(channel) = self.channels.get_mut(&channel_id) {
103 channel.connection_ids.remove(&connection_id);
104 }
105 }
106
107 let user_connections = self
108 .connections_by_user_id
109 .get_mut(&connection.user_id)
110 .unwrap();
111 user_connections.remove(&connection_id);
112 if user_connections.is_empty() {
113 self.connections_by_user_id.remove(&connection.user_id);
114 }
115
116 let mut result = RemovedConnectionState::default();
117 for project_id in connection.projects.clone() {
118 if let Some((project, authorized_user_ids)) =
119 self.unregister_project(project_id, connection_id)
120 {
121 result.contact_ids.extend(authorized_user_ids);
122 result.hosted_projects.insert(project_id, project);
123 } else if let Some(project) = self.leave_project(connection_id, project_id) {
124 result
125 .guest_project_ids
126 .insert(project_id, project.connection_ids);
127 result.contact_ids.extend(project.authorized_user_ids);
128 }
129 }
130
131 #[cfg(test)]
132 self.check_invariants();
133
134 Ok(result)
135 }
136
137 #[cfg(test)]
138 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
139 self.channels.get(&id)
140 }
141
142 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143 if let Some(connection) = self.connections.get_mut(&connection_id) {
144 connection.channels.insert(channel_id);
145 self.channels
146 .entry(channel_id)
147 .or_default()
148 .connection_ids
149 .insert(connection_id);
150 }
151 }
152
153 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
154 if let Some(connection) = self.connections.get_mut(&connection_id) {
155 connection.channels.remove(&channel_id);
156 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
157 entry.get_mut().connection_ids.remove(&connection_id);
158 if entry.get_mut().connection_ids.is_empty() {
159 entry.remove();
160 }
161 }
162 }
163 }
164
165 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
166 Ok(self
167 .connections
168 .get(&connection_id)
169 .ok_or_else(|| anyhow!("unknown connection"))?
170 .user_id)
171 }
172
173 pub fn connection_ids_for_user<'a>(
174 &'a self,
175 user_id: UserId,
176 ) -> impl 'a + Iterator<Item = ConnectionId> {
177 self.connections_by_user_id
178 .get(&user_id)
179 .into_iter()
180 .flatten()
181 .copied()
182 }
183
184 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
185 let mut contacts = HashMap::default();
186 for project_id in self
187 .visible_projects_by_user_id
188 .get(&user_id)
189 .unwrap_or(&HashSet::default())
190 {
191 let project = &self.projects[project_id];
192
193 let mut guests = HashSet::default();
194 if let Ok(share) = project.share() {
195 for guest_connection_id in share.guests.keys() {
196 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
197 guests.insert(user_id.to_proto());
198 }
199 }
200 }
201
202 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
203 let mut worktree_root_names = project
204 .worktrees
205 .values()
206 .map(|worktree| worktree.root_name.clone())
207 .collect::<Vec<_>>();
208 worktree_root_names.sort_unstable();
209 contacts
210 .entry(host_user_id)
211 .or_insert_with(|| proto::Contact {
212 user_id: host_user_id.to_proto(),
213 projects: Vec::new(),
214 })
215 .projects
216 .push(proto::ProjectMetadata {
217 id: *project_id,
218 worktree_root_names,
219 is_shared: project.share.is_some(),
220 guests: guests.into_iter().collect(),
221 });
222 }
223 }
224
225 contacts.into_values().collect()
226 }
227
228 pub fn register_project(
229 &mut self,
230 host_connection_id: ConnectionId,
231 host_user_id: UserId,
232 ) -> u64 {
233 let project_id = self.next_project_id;
234 self.projects.insert(
235 project_id,
236 Project {
237 host_connection_id,
238 host_user_id,
239 share: None,
240 worktrees: Default::default(),
241 },
242 );
243 self.next_project_id += 1;
244 project_id
245 }
246
247 pub fn register_worktree(
248 &mut self,
249 project_id: u64,
250 worktree_id: u64,
251 worktree: Worktree,
252 ) -> bool {
253 if let Some(project) = self.projects.get_mut(&project_id) {
254 for authorized_user_id in &worktree.authorized_user_ids {
255 self.visible_projects_by_user_id
256 .entry(*authorized_user_id)
257 .or_default()
258 .insert(project_id);
259 }
260 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
261 connection.projects.insert(project_id);
262 }
263 project.worktrees.insert(worktree_id, worktree);
264
265 #[cfg(test)]
266 self.check_invariants();
267 true
268 } else {
269 false
270 }
271 }
272
273 pub fn unregister_project(
274 &mut self,
275 project_id: u64,
276 connection_id: ConnectionId,
277 ) -> Option<(Project, Vec<UserId>)> {
278 match self.projects.entry(project_id) {
279 hash_map::Entry::Occupied(e) => {
280 if e.get().host_connection_id != connection_id {
281 return None;
282 }
283 }
284 hash_map::Entry::Vacant(_) => return None,
285 }
286
287 todo!()
288 }
289
290 pub fn unregister_worktree(
291 &mut self,
292 project_id: u64,
293 worktree_id: u64,
294 acting_connection_id: ConnectionId,
295 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
296 let project = self
297 .projects
298 .get_mut(&project_id)
299 .ok_or_else(|| anyhow!("no such project"))?;
300 if project.host_connection_id != acting_connection_id {
301 Err(anyhow!("not your worktree"))?;
302 }
303
304 let worktree = project
305 .worktrees
306 .remove(&worktree_id)
307 .ok_or_else(|| anyhow!("no such worktree"))?;
308
309 let mut guest_connection_ids = Vec::new();
310 if let Some(share) = &project.share {
311 guest_connection_ids.extend(share.guests.keys());
312 }
313
314 for authorized_user_id in &worktree.authorized_user_ids {
315 if let Some(visible_projects) =
316 self.visible_projects_by_user_id.get_mut(authorized_user_id)
317 {
318 if !project.has_authorized_user_id(*authorized_user_id) {
319 visible_projects.remove(&project_id);
320 }
321 }
322 }
323
324 #[cfg(test)]
325 self.check_invariants();
326
327 Ok((worktree, guest_connection_ids))
328 }
329
330 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
331 if let Some(project) = self.projects.get_mut(&project_id) {
332 if project.host_connection_id == connection_id {
333 project.share = Some(ProjectShare::default());
334 return true;
335 }
336 }
337 false
338 }
339
340 pub fn unshare_project(
341 &mut self,
342 project_id: u64,
343 acting_connection_id: ConnectionId,
344 ) -> tide::Result<UnsharedWorktree> {
345 let project = if let Some(project) = self.projects.get_mut(&project_id) {
346 project
347 } else {
348 return Err(anyhow!("no such project"))?;
349 };
350
351 if project.host_connection_id != acting_connection_id {
352 return Err(anyhow!("not your project"))?;
353 }
354
355 let connection_ids = project.connection_ids();
356 let authorized_user_ids = project.authorized_user_ids();
357 if let Some(share) = project.share.take() {
358 for connection_id in share.guests.into_keys() {
359 if let Some(connection) = self.connections.get_mut(&connection_id) {
360 connection.projects.remove(&project_id);
361 }
362 }
363
364 #[cfg(test)]
365 self.check_invariants();
366
367 Ok(UnsharedWorktree {
368 connection_ids,
369 authorized_user_ids,
370 })
371 } else {
372 Err(anyhow!("project is not shared"))?
373 }
374 }
375
376 pub fn share_worktree(
377 &mut self,
378 project_id: u64,
379 worktree_id: u64,
380 connection_id: ConnectionId,
381 entries: HashMap<u64, proto::Entry>,
382 ) -> Option<Vec<UserId>> {
383 let project = self.projects.get_mut(&project_id)?;
384 let worktree = project.worktrees.get_mut(&worktree_id)?;
385 if project.host_connection_id == connection_id && project.share.is_some() {
386 worktree.share = Some(WorktreeShare { entries });
387 Some(project.authorized_user_ids())
388 } else {
389 None
390 }
391 }
392
393 pub fn join_project(
394 &mut self,
395 connection_id: ConnectionId,
396 user_id: UserId,
397 project_id: u64,
398 ) -> tide::Result<JoinedProject> {
399 let connection = self
400 .connections
401 .get_mut(&connection_id)
402 .ok_or_else(|| anyhow!("no such connection"))?;
403 let project = self
404 .projects
405 .get_mut(&project_id)
406 .and_then(|project| {
407 if project.has_authorized_user_id(user_id) {
408 Some(project)
409 } else {
410 None
411 }
412 })
413 .ok_or_else(|| anyhow!("no such project"))?;
414
415 let share = project.share_mut()?;
416 connection.projects.insert(project_id);
417
418 let mut replica_id = 1;
419 while share.active_replica_ids.contains(&replica_id) {
420 replica_id += 1;
421 }
422 share.active_replica_ids.insert(replica_id);
423 share.guests.insert(connection_id, (replica_id, user_id));
424
425 #[cfg(test)]
426 self.check_invariants();
427
428 Ok(JoinedProject {
429 replica_id,
430 project: &self.projects[&project_id],
431 })
432 }
433
434 pub fn leave_project(
435 &mut self,
436 connection_id: ConnectionId,
437 project_id: u64,
438 ) -> Option<LeftProject> {
439 let project = self.projects.get_mut(&project_id)?;
440 let share = project.share.as_mut()?;
441 let (replica_id, _) = share.guests.remove(&connection_id)?;
442 share.active_replica_ids.remove(&replica_id);
443
444 if let Some(connection) = self.connections.get_mut(&connection_id) {
445 connection.projects.remove(&project_id);
446 }
447
448 let connection_ids = project.connection_ids();
449 let authorized_user_ids = project.authorized_user_ids();
450
451 #[cfg(test)]
452 self.check_invariants();
453
454 Some(LeftProject {
455 connection_ids,
456 authorized_user_ids,
457 })
458 }
459
460 pub fn update_worktree(
461 &mut self,
462 connection_id: ConnectionId,
463 project_id: u64,
464 worktree_id: u64,
465 removed_entries: &[u64],
466 updated_entries: &[proto::Entry],
467 ) -> Option<Vec<ConnectionId>> {
468 let project = self.write_project(project_id, connection_id)?;
469 let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
470 for entry_id in removed_entries {
471 share.entries.remove(&entry_id);
472 }
473 for entry in updated_entries {
474 share.entries.insert(entry.id, entry.clone());
475 }
476 Some(project.connection_ids())
477 }
478
479 pub fn project_connection_ids(
480 &self,
481 project_id: u64,
482 acting_connection_id: ConnectionId,
483 ) -> Option<Vec<ConnectionId>> {
484 Some(
485 self.read_project(project_id, acting_connection_id)?
486 .connection_ids(),
487 )
488 }
489
490 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
491 Some(self.channels.get(&channel_id)?.connection_ids())
492 }
493
494 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
495 let project = self.projects.get(&project_id)?;
496 if project.host_connection_id == connection_id
497 || project.share.as_ref()?.guests.contains_key(&connection_id)
498 {
499 Some(project)
500 } else {
501 None
502 }
503 }
504
505 fn write_project(
506 &mut self,
507 project_id: u64,
508 connection_id: ConnectionId,
509 ) -> Option<&mut Project> {
510 let project = self.projects.get_mut(&project_id)?;
511 if project.host_connection_id == connection_id
512 || project.share.as_ref()?.guests.contains_key(&connection_id)
513 {
514 Some(project)
515 } else {
516 None
517 }
518 }
519
520 #[cfg(test)]
521 fn check_invariants(&self) {
522 for (connection_id, connection) in &self.connections {
523 for project_id in &connection.projects {
524 let project = &self.projects.get(&project_id).unwrap();
525 if project.host_connection_id != *connection_id {
526 assert!(project
527 .share
528 .as_ref()
529 .unwrap()
530 .guests
531 .contains_key(connection_id));
532 }
533 }
534 for channel_id in &connection.channels {
535 let channel = self.channels.get(channel_id).unwrap();
536 assert!(channel.connection_ids.contains(connection_id));
537 }
538 assert!(self
539 .connections_by_user_id
540 .get(&connection.user_id)
541 .unwrap()
542 .contains(connection_id));
543 }
544
545 for (user_id, connection_ids) in &self.connections_by_user_id {
546 for connection_id in connection_ids {
547 assert_eq!(
548 self.connections.get(connection_id).unwrap().user_id,
549 *user_id
550 );
551 }
552 }
553
554 for (project_id, project) in &self.projects {
555 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
556 assert!(host_connection.projects.contains(project_id));
557
558 for authorized_user_ids in project.authorized_user_ids() {
559 let visible_project_ids = self
560 .visible_projects_by_user_id
561 .get(&authorized_user_ids)
562 .unwrap();
563 assert!(visible_project_ids.contains(project_id));
564 }
565
566 if let Some(share) = &project.share {
567 for guest_connection_id in share.guests.keys() {
568 let guest_connection = self.connections.get(guest_connection_id).unwrap();
569 assert!(guest_connection.projects.contains(project_id));
570 }
571 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
572 assert_eq!(
573 share.active_replica_ids,
574 share
575 .guests
576 .values()
577 .map(|(replica_id, _)| *replica_id)
578 .collect::<HashSet<_>>(),
579 );
580 }
581 }
582
583 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
584 for project_id in visible_project_ids {
585 let project = self.projects.get(project_id).unwrap();
586 assert!(project.authorized_user_ids().contains(user_id));
587 }
588 }
589
590 for (channel_id, channel) in &self.channels {
591 for connection_id in &channel.connection_ids {
592 let connection = self.connections.get(connection_id).unwrap();
593 assert!(connection.channels.contains(channel_id));
594 }
595 }
596 }
597}
598
599impl Project {
600 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
601 self.worktrees
602 .values()
603 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
604 }
605
606 pub fn authorized_user_ids(&self) -> Vec<UserId> {
607 let mut ids = self
608 .worktrees
609 .values()
610 .flat_map(|worktree| worktree.authorized_user_ids.iter())
611 .copied()
612 .collect::<Vec<_>>();
613 ids.sort_unstable();
614 ids.dedup();
615 ids
616 }
617
618 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
619 if let Some(share) = &self.share {
620 share.guests.keys().copied().collect()
621 } else {
622 Vec::new()
623 }
624 }
625
626 pub fn connection_ids(&self) -> Vec<ConnectionId> {
627 if let Some(share) = &self.share {
628 share
629 .guests
630 .keys()
631 .copied()
632 .chain(Some(self.host_connection_id))
633 .collect()
634 } else {
635 vec![self.host_connection_id]
636 }
637 }
638
639 pub fn share(&self) -> tide::Result<&ProjectShare> {
640 Ok(self
641 .share
642 .as_ref()
643 .ok_or_else(|| anyhow!("worktree is not shared"))?)
644 }
645
646 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
647 Ok(self
648 .share
649 .as_mut()
650 .ok_or_else(|| anyhow!("worktree is not shared"))?)
651 }
652}
653
654impl Channel {
655 fn connection_ids(&self) -> Vec<ConnectionId> {
656 self.connection_ids.iter().copied().collect()
657 }
658}