1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::collections::hash_map;
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 pub worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33 pub share: Option<WorktreeShare>,
34}
35
36#[derive(Default)]
37pub struct ProjectShare {
38 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
39 pub active_replica_ids: HashSet<ReplicaId>,
40}
41
42pub struct WorktreeShare {
43 pub entries: HashMap<u64, proto::Entry>,
44}
45
46#[derive(Default)]
47pub struct Channel {
48 pub connection_ids: HashSet<ConnectionId>,
49}
50
/// Identifier assigned to each guest replica of a shared project.
pub type ReplicaId = u16;
52
53#[derive(Default)]
54pub struct RemovedConnectionState {
55 pub hosted_projects: HashMap<u64, Project>,
56 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
57 pub contact_ids: HashSet<UserId>,
58}
59
60pub struct JoinedProject<'a> {
61 pub replica_id: ReplicaId,
62 pub project: &'a Project,
63}
64
65pub struct UnsharedWorktree {
66 pub connection_ids: Vec<ConnectionId>,
67 pub authorized_user_ids: Vec<UserId>,
68}
69
70pub struct LeftProject {
71 pub connection_ids: Vec<ConnectionId>,
72 pub authorized_user_ids: Vec<UserId>,
73}
74
75impl Store {
76 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
77 self.connections.insert(
78 connection_id,
79 ConnectionState {
80 user_id,
81 projects: Default::default(),
82 channels: Default::default(),
83 },
84 );
85 self.connections_by_user_id
86 .entry(user_id)
87 .or_default()
88 .insert(connection_id);
89 }
90
91 pub fn remove_connection(
92 &mut self,
93 connection_id: ConnectionId,
94 ) -> tide::Result<RemovedConnectionState> {
95 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
96 connection
97 } else {
98 return Err(anyhow!("no such connection"))?;
99 };
100
101 for channel_id in &connection.channels {
102 if let Some(channel) = self.channels.get_mut(&channel_id) {
103 channel.connection_ids.remove(&connection_id);
104 }
105 }
106
107 let user_connections = self
108 .connections_by_user_id
109 .get_mut(&connection.user_id)
110 .unwrap();
111 user_connections.remove(&connection_id);
112 if user_connections.is_empty() {
113 self.connections_by_user_id.remove(&connection.user_id);
114 }
115
116 let mut result = RemovedConnectionState::default();
117 for project_id in connection.projects.clone() {
118 if let Some(project) = self.unregister_project(project_id, connection_id) {
119 result.contact_ids.extend(project.authorized_user_ids());
120 result.hosted_projects.insert(project_id, project);
121 } else if let Some(project) = self.leave_project(connection_id, project_id) {
122 result
123 .guest_project_ids
124 .insert(project_id, project.connection_ids);
125 result.contact_ids.extend(project.authorized_user_ids);
126 }
127 }
128
129 #[cfg(test)]
130 self.check_invariants();
131
132 Ok(result)
133 }
134
135 #[cfg(test)]
136 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
137 self.channels.get(&id)
138 }
139
140 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
141 if let Some(connection) = self.connections.get_mut(&connection_id) {
142 connection.channels.insert(channel_id);
143 self.channels
144 .entry(channel_id)
145 .or_default()
146 .connection_ids
147 .insert(connection_id);
148 }
149 }
150
151 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
152 if let Some(connection) = self.connections.get_mut(&connection_id) {
153 connection.channels.remove(&channel_id);
154 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
155 entry.get_mut().connection_ids.remove(&connection_id);
156 if entry.get_mut().connection_ids.is_empty() {
157 entry.remove();
158 }
159 }
160 }
161 }
162
163 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
164 Ok(self
165 .connections
166 .get(&connection_id)
167 .ok_or_else(|| anyhow!("unknown connection"))?
168 .user_id)
169 }
170
171 pub fn connection_ids_for_user<'a>(
172 &'a self,
173 user_id: UserId,
174 ) -> impl 'a + Iterator<Item = ConnectionId> {
175 self.connections_by_user_id
176 .get(&user_id)
177 .into_iter()
178 .flatten()
179 .copied()
180 }
181
182 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
183 let mut contacts = HashMap::default();
184 for project_id in self
185 .visible_projects_by_user_id
186 .get(&user_id)
187 .unwrap_or(&HashSet::default())
188 {
189 let project = &self.projects[project_id];
190
191 let mut guests = HashSet::default();
192 if let Ok(share) = project.share() {
193 for guest_connection_id in share.guests.keys() {
194 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
195 guests.insert(user_id.to_proto());
196 }
197 }
198 }
199
200 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
201 let mut worktree_root_names = project
202 .worktrees
203 .values()
204 .map(|worktree| worktree.root_name.clone())
205 .collect::<Vec<_>>();
206 worktree_root_names.sort_unstable();
207 contacts
208 .entry(host_user_id)
209 .or_insert_with(|| proto::Contact {
210 user_id: host_user_id.to_proto(),
211 projects: Vec::new(),
212 })
213 .projects
214 .push(proto::ProjectMetadata {
215 id: *project_id,
216 worktree_root_names,
217 is_shared: project.share.is_some(),
218 guests: guests.into_iter().collect(),
219 });
220 }
221 }
222
223 contacts.into_values().collect()
224 }
225
226 pub fn register_project(
227 &mut self,
228 host_connection_id: ConnectionId,
229 host_user_id: UserId,
230 ) -> u64 {
231 let project_id = self.next_project_id;
232 self.projects.insert(
233 project_id,
234 Project {
235 host_connection_id,
236 host_user_id,
237 share: None,
238 worktrees: Default::default(),
239 },
240 );
241 self.next_project_id += 1;
242 project_id
243 }
244
245 pub fn register_worktree(
246 &mut self,
247 project_id: u64,
248 worktree_id: u64,
249 worktree: Worktree,
250 ) -> bool {
251 if let Some(project) = self.projects.get_mut(&project_id) {
252 for authorized_user_id in &worktree.authorized_user_ids {
253 self.visible_projects_by_user_id
254 .entry(*authorized_user_id)
255 .or_default()
256 .insert(project_id);
257 }
258 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
259 connection.projects.insert(project_id);
260 }
261 project.worktrees.insert(worktree_id, worktree);
262
263 #[cfg(test)]
264 self.check_invariants();
265 true
266 } else {
267 false
268 }
269 }
270
271 pub fn unregister_project(
272 &mut self,
273 project_id: u64,
274 connection_id: ConnectionId,
275 ) -> Option<Project> {
276 match self.projects.entry(project_id) {
277 hash_map::Entry::Occupied(e) => {
278 if e.get().host_connection_id == connection_id {
279 for user_id in e.get().authorized_user_ids() {
280 if let hash_map::Entry::Occupied(mut projects) =
281 self.visible_projects_by_user_id.entry(user_id)
282 {
283 projects.get_mut().remove(&project_id);
284 }
285 }
286
287 Some(e.remove())
288 } else {
289 None
290 }
291 }
292 hash_map::Entry::Vacant(_) => None,
293 }
294 }
295
296 pub fn unregister_worktree(
297 &mut self,
298 project_id: u64,
299 worktree_id: u64,
300 acting_connection_id: ConnectionId,
301 ) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
302 let project = self
303 .projects
304 .get_mut(&project_id)
305 .ok_or_else(|| anyhow!("no such project"))?;
306 if project.host_connection_id != acting_connection_id {
307 Err(anyhow!("not your worktree"))?;
308 }
309
310 let worktree = project
311 .worktrees
312 .remove(&worktree_id)
313 .ok_or_else(|| anyhow!("no such worktree"))?;
314
315 let mut guest_connection_ids = Vec::new();
316 if let Some(share) = &project.share {
317 guest_connection_ids.extend(share.guests.keys());
318 }
319
320 for authorized_user_id in &worktree.authorized_user_ids {
321 if let Some(visible_projects) =
322 self.visible_projects_by_user_id.get_mut(authorized_user_id)
323 {
324 if !project.has_authorized_user_id(*authorized_user_id) {
325 visible_projects.remove(&project_id);
326 }
327 }
328 }
329
330 #[cfg(test)]
331 self.check_invariants();
332
333 Ok((worktree, guest_connection_ids))
334 }
335
336 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
337 if let Some(project) = self.projects.get_mut(&project_id) {
338 if project.host_connection_id == connection_id {
339 project.share = Some(ProjectShare::default());
340 return true;
341 }
342 }
343 false
344 }
345
346 pub fn unshare_project(
347 &mut self,
348 project_id: u64,
349 acting_connection_id: ConnectionId,
350 ) -> tide::Result<UnsharedWorktree> {
351 let project = if let Some(project) = self.projects.get_mut(&project_id) {
352 project
353 } else {
354 return Err(anyhow!("no such project"))?;
355 };
356
357 if project.host_connection_id != acting_connection_id {
358 return Err(anyhow!("not your project"))?;
359 }
360
361 let connection_ids = project.connection_ids();
362 let authorized_user_ids = project.authorized_user_ids();
363 if let Some(share) = project.share.take() {
364 for connection_id in share.guests.into_keys() {
365 if let Some(connection) = self.connections.get_mut(&connection_id) {
366 connection.projects.remove(&project_id);
367 }
368 }
369
370 #[cfg(test)]
371 self.check_invariants();
372
373 Ok(UnsharedWorktree {
374 connection_ids,
375 authorized_user_ids,
376 })
377 } else {
378 Err(anyhow!("project is not shared"))?
379 }
380 }
381
382 pub fn share_worktree(
383 &mut self,
384 project_id: u64,
385 worktree_id: u64,
386 connection_id: ConnectionId,
387 entries: HashMap<u64, proto::Entry>,
388 ) -> Option<Vec<UserId>> {
389 let project = self.projects.get_mut(&project_id)?;
390 let worktree = project.worktrees.get_mut(&worktree_id)?;
391 if project.host_connection_id == connection_id && project.share.is_some() {
392 worktree.share = Some(WorktreeShare { entries });
393 Some(project.authorized_user_ids())
394 } else {
395 None
396 }
397 }
398
399 pub fn join_project(
400 &mut self,
401 connection_id: ConnectionId,
402 user_id: UserId,
403 project_id: u64,
404 ) -> tide::Result<JoinedProject> {
405 let connection = self
406 .connections
407 .get_mut(&connection_id)
408 .ok_or_else(|| anyhow!("no such connection"))?;
409 let project = self
410 .projects
411 .get_mut(&project_id)
412 .and_then(|project| {
413 if project.has_authorized_user_id(user_id) {
414 Some(project)
415 } else {
416 None
417 }
418 })
419 .ok_or_else(|| anyhow!("no such project"))?;
420
421 let share = project.share_mut()?;
422 connection.projects.insert(project_id);
423
424 let mut replica_id = 1;
425 while share.active_replica_ids.contains(&replica_id) {
426 replica_id += 1;
427 }
428 share.active_replica_ids.insert(replica_id);
429 share.guests.insert(connection_id, (replica_id, user_id));
430
431 #[cfg(test)]
432 self.check_invariants();
433
434 Ok(JoinedProject {
435 replica_id,
436 project: &self.projects[&project_id],
437 })
438 }
439
440 pub fn leave_project(
441 &mut self,
442 connection_id: ConnectionId,
443 project_id: u64,
444 ) -> Option<LeftProject> {
445 let project = self.projects.get_mut(&project_id)?;
446 let share = project.share.as_mut()?;
447 let (replica_id, _) = share.guests.remove(&connection_id)?;
448 share.active_replica_ids.remove(&replica_id);
449
450 if let Some(connection) = self.connections.get_mut(&connection_id) {
451 connection.projects.remove(&project_id);
452 }
453
454 let connection_ids = project.connection_ids();
455 let authorized_user_ids = project.authorized_user_ids();
456
457 #[cfg(test)]
458 self.check_invariants();
459
460 Some(LeftProject {
461 connection_ids,
462 authorized_user_ids,
463 })
464 }
465
466 pub fn update_worktree(
467 &mut self,
468 connection_id: ConnectionId,
469 project_id: u64,
470 worktree_id: u64,
471 removed_entries: &[u64],
472 updated_entries: &[proto::Entry],
473 ) -> Option<Vec<ConnectionId>> {
474 let project = self.write_project(project_id, connection_id)?;
475 let share = project.worktrees.get_mut(&worktree_id)?.share.as_mut()?;
476 for entry_id in removed_entries {
477 share.entries.remove(&entry_id);
478 }
479 for entry in updated_entries {
480 share.entries.insert(entry.id, entry.clone());
481 }
482 Some(project.connection_ids())
483 }
484
485 pub fn project_connection_ids(
486 &self,
487 project_id: u64,
488 acting_connection_id: ConnectionId,
489 ) -> Option<Vec<ConnectionId>> {
490 Some(
491 self.read_project(project_id, acting_connection_id)?
492 .connection_ids(),
493 )
494 }
495
496 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
497 Some(self.channels.get(&channel_id)?.connection_ids())
498 }
499
500 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Option<&Project> {
501 let project = self.projects.get(&project_id)?;
502 if project.host_connection_id == connection_id
503 || project.share.as_ref()?.guests.contains_key(&connection_id)
504 {
505 Some(project)
506 } else {
507 None
508 }
509 }
510
511 fn write_project(
512 &mut self,
513 project_id: u64,
514 connection_id: ConnectionId,
515 ) -> Option<&mut Project> {
516 let project = self.projects.get_mut(&project_id)?;
517 if project.host_connection_id == connection_id
518 || project.share.as_ref()?.guests.contains_key(&connection_id)
519 {
520 Some(project)
521 } else {
522 None
523 }
524 }
525
526 #[cfg(test)]
527 fn check_invariants(&self) {
528 for (connection_id, connection) in &self.connections {
529 for project_id in &connection.projects {
530 let project = &self.projects.get(&project_id).unwrap();
531 if project.host_connection_id != *connection_id {
532 assert!(project
533 .share
534 .as_ref()
535 .unwrap()
536 .guests
537 .contains_key(connection_id));
538 }
539 }
540 for channel_id in &connection.channels {
541 let channel = self.channels.get(channel_id).unwrap();
542 assert!(channel.connection_ids.contains(connection_id));
543 }
544 assert!(self
545 .connections_by_user_id
546 .get(&connection.user_id)
547 .unwrap()
548 .contains(connection_id));
549 }
550
551 for (user_id, connection_ids) in &self.connections_by_user_id {
552 for connection_id in connection_ids {
553 assert_eq!(
554 self.connections.get(connection_id).unwrap().user_id,
555 *user_id
556 );
557 }
558 }
559
560 for (project_id, project) in &self.projects {
561 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
562 assert!(host_connection.projects.contains(project_id));
563
564 for authorized_user_ids in project.authorized_user_ids() {
565 let visible_project_ids = self
566 .visible_projects_by_user_id
567 .get(&authorized_user_ids)
568 .unwrap();
569 assert!(visible_project_ids.contains(project_id));
570 }
571
572 if let Some(share) = &project.share {
573 for guest_connection_id in share.guests.keys() {
574 let guest_connection = self.connections.get(guest_connection_id).unwrap();
575 assert!(guest_connection.projects.contains(project_id));
576 }
577 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
578 assert_eq!(
579 share.active_replica_ids,
580 share
581 .guests
582 .values()
583 .map(|(replica_id, _)| *replica_id)
584 .collect::<HashSet<_>>(),
585 );
586 }
587 }
588
589 for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
590 for project_id in visible_project_ids {
591 let project = self.projects.get(project_id).unwrap();
592 assert!(project.authorized_user_ids().contains(user_id));
593 }
594 }
595
596 for (channel_id, channel) in &self.channels {
597 for connection_id in &channel.connection_ids {
598 let connection = self.connections.get(connection_id).unwrap();
599 assert!(connection.channels.contains(channel_id));
600 }
601 }
602 }
603}
604
605impl Project {
606 pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
607 self.worktrees
608 .values()
609 .any(|worktree| worktree.authorized_user_ids.contains(&user_id))
610 }
611
612 pub fn authorized_user_ids(&self) -> Vec<UserId> {
613 let mut ids = self
614 .worktrees
615 .values()
616 .flat_map(|worktree| worktree.authorized_user_ids.iter())
617 .copied()
618 .collect::<Vec<_>>();
619 ids.sort_unstable();
620 ids.dedup();
621 ids
622 }
623
624 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
625 if let Some(share) = &self.share {
626 share.guests.keys().copied().collect()
627 } else {
628 Vec::new()
629 }
630 }
631
632 pub fn connection_ids(&self) -> Vec<ConnectionId> {
633 if let Some(share) = &self.share {
634 share
635 .guests
636 .keys()
637 .copied()
638 .chain(Some(self.host_connection_id))
639 .collect()
640 } else {
641 vec![self.host_connection_id]
642 }
643 }
644
645 pub fn share(&self) -> tide::Result<&ProjectShare> {
646 Ok(self
647 .share
648 .as_ref()
649 .ok_or_else(|| anyhow!("worktree is not shared"))?)
650 }
651
652 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
653 Ok(self
654 .share
655 .as_mut()
656 .ok_or_else(|| anyhow!("worktree is not shared"))?)
657 }
658}
659
660impl Channel {
661 fn connection_ids(&self) -> Vec<ConnectionId> {
662 self.connection_ids.iter().copied().collect()
663 }
664}