1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::collections::hash_map;
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 projects: HashMap<u64, Project>,
12 visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_worktree_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Project {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub share: Option<ProjectShare>,
27 worktrees: HashMap<u64, Worktree>,
28}
29
30pub struct Worktree {
31 pub authorized_user_ids: Vec<UserId>,
32 pub root_name: String,
33}
34
35#[derive(Default)]
36pub struct ProjectShare {
37 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
38 pub active_replica_ids: HashSet<ReplicaId>,
39 pub worktrees: HashMap<u64, WorktreeShare>,
40}
41
42pub struct WorktreeShare {
43 pub entries: HashMap<u64, proto::Entry>,
44}
45
46#[derive(Default)]
47pub struct Channel {
48 pub connection_ids: HashSet<ConnectionId>,
49}
50
/// Small integer identifying one participant's replica of a shared project.
pub type ReplicaId = u16;
52
53#[derive(Default)]
54pub struct RemovedConnectionState {
55 pub hosted_projects: HashMap<u64, Project>,
56 pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
57 pub contact_ids: HashSet<UserId>,
58}
59
60pub struct JoinedWorktree<'a> {
61 pub replica_id: ReplicaId,
62 pub worktree: &'a Worktree,
63}
64
65pub struct UnsharedWorktree {
66 pub connection_ids: Vec<ConnectionId>,
67 pub authorized_user_ids: Vec<UserId>,
68}
69
70pub struct LeftWorktree {
71 pub connection_ids: Vec<ConnectionId>,
72 pub authorized_user_ids: Vec<UserId>,
73}
74
75impl Store {
76 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
77 self.connections.insert(
78 connection_id,
79 ConnectionState {
80 user_id,
81 projects: Default::default(),
82 channels: Default::default(),
83 },
84 );
85 self.connections_by_user_id
86 .entry(user_id)
87 .or_default()
88 .insert(connection_id);
89 }
90
91 pub fn remove_connection(
92 &mut self,
93 connection_id: ConnectionId,
94 ) -> tide::Result<RemovedConnectionState> {
95 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
96 connection
97 } else {
98 return Err(anyhow!("no such connection"))?;
99 };
100
101 for channel_id in &connection.channels {
102 if let Some(channel) = self.channels.get_mut(&channel_id) {
103 channel.connection_ids.remove(&connection_id);
104 }
105 }
106
107 let user_connections = self
108 .connections_by_user_id
109 .get_mut(&connection.user_id)
110 .unwrap();
111 user_connections.remove(&connection_id);
112 if user_connections.is_empty() {
113 self.connections_by_user_id.remove(&connection.user_id);
114 }
115
116 let mut result = RemovedConnectionState::default();
117 for worktree_id in connection.worktrees.clone() {
118 if let Ok(worktree) = self.unregister_worktree(worktree_id, connection_id) {
119 result
120 .contact_ids
121 .extend(worktree.authorized_user_ids.iter().copied());
122 result.hosted_worktrees.insert(worktree_id, worktree);
123 } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
124 result
125 .guest_worktree_ids
126 .insert(worktree_id, worktree.connection_ids);
127 result.contact_ids.extend(worktree.authorized_user_ids);
128 }
129 }
130
131 #[cfg(test)]
132 self.check_invariants();
133
134 Ok(result)
135 }
136
137 #[cfg(test)]
138 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
139 self.channels.get(&id)
140 }
141
142 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143 if let Some(connection) = self.connections.get_mut(&connection_id) {
144 connection.channels.insert(channel_id);
145 self.channels
146 .entry(channel_id)
147 .or_default()
148 .connection_ids
149 .insert(connection_id);
150 }
151 }
152
153 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
154 if let Some(connection) = self.connections.get_mut(&connection_id) {
155 connection.channels.remove(&channel_id);
156 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
157 entry.get_mut().connection_ids.remove(&connection_id);
158 if entry.get_mut().connection_ids.is_empty() {
159 entry.remove();
160 }
161 }
162 }
163 }
164
165 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
166 Ok(self
167 .connections
168 .get(&connection_id)
169 .ok_or_else(|| anyhow!("unknown connection"))?
170 .user_id)
171 }
172
173 pub fn connection_ids_for_user<'a>(
174 &'a self,
175 user_id: UserId,
176 ) -> impl 'a + Iterator<Item = ConnectionId> {
177 self.connections_by_user_id
178 .get(&user_id)
179 .into_iter()
180 .flatten()
181 .copied()
182 }
183
184 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
185 let mut contacts = HashMap::default();
186 for project_id in self
187 .visible_projects_by_user_id
188 .get(&user_id)
189 .unwrap_or(&HashSet::default())
190 {
191 let project = &self.projects[project_id];
192
193 let mut guests = HashSet::default();
194 if let Ok(share) = worktree.share() {
195 for guest_connection_id in share.guests.keys() {
196 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
197 guests.insert(user_id.to_proto());
198 }
199 }
200 }
201
202 if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
203 contacts
204 .entry(host_user_id)
205 .or_insert_with(|| proto::Contact {
206 user_id: host_user_id.to_proto(),
207 projects: Vec::new(),
208 })
209 .projects
210 .push(proto::ProjectMetadata {
211 id: *project_id,
212 worktree_root_names: project
213 .worktrees
214 .iter()
215 .map(|worktree| worktree.root_name.clone())
216 .collect(),
217 is_shared: project.share.is_some(),
218 guests: guests.into_iter().collect(),
219 });
220 }
221 }
222
223 contacts.into_values().collect()
224 }
225
226 pub fn register_project(
227 &mut self,
228 host_connection_id: ConnectionId,
229 host_user_id: UserId,
230 ) -> u64 {
231 let project_id = self.next_project_id;
232 self.projects.insert(
233 project_id,
234 Project {
235 host_connection_id,
236 host_user_id,
237 share: None,
238 worktrees: Default::default(),
239 },
240 );
241 self.next_project_id += 1;
242 project_id
243 }
244
245 pub fn register_worktree(
246 &mut self,
247 project_id: u64,
248 worktree_id: u64,
249 worktree: Worktree,
250 ) -> bool {
251 if let Some(project) = self.projects.get_mut(&project_id) {
252 for authorized_user_id in &worktree.authorized_user_ids {
253 self.visible_projects_by_user_id
254 .entry(*authorized_user_id)
255 .or_default()
256 .insert(project_id);
257 }
258 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
259 connection.projects.insert(project_id);
260 }
261 project.worktrees.insert(worktree_id, worktree);
262
263 #[cfg(test)]
264 self.check_invariants();
265 true
266 } else {
267 false
268 }
269 }
270
271 pub fn unregister_project(&mut self, project_id: u64) {
272 todo!()
273 }
274
275 pub fn unregister_worktree(
276 &mut self,
277 project_id: u64,
278 worktree_id: u64,
279 acting_connection_id: ConnectionId,
280 ) -> tide::Result<Worktree> {
281 let project = self
282 .projects
283 .get_mut(&project_id)
284 .ok_or_else(|| anyhow!("no such project"))?;
285 if project.host_connection_id != acting_connection_id {
286 Err(anyhow!("not your worktree"))?;
287 }
288
289 let worktree = project
290 .worktrees
291 .remove(&worktree_id)
292 .ok_or_else(|| anyhow!("no such worktree"))?;
293
294 if let Some(connection) = self.connections.get_mut(&project.host_connection_id) {
295 connection.worktrees.remove(&worktree_id);
296 }
297
298 if let Some(share) = &worktree.share {
299 for connection_id in share.guests.keys() {
300 if let Some(connection) = self.connections.get_mut(connection_id) {
301 connection.worktrees.remove(&worktree_id);
302 }
303 }
304 }
305
306 for authorized_user_id in &worktree.authorized_user_ids {
307 if let Some(visible_worktrees) = self
308 .visible_worktrees_by_user_id
309 .get_mut(&authorized_user_id)
310 {
311 visible_worktrees.remove(&worktree_id);
312 }
313 }
314
315 #[cfg(test)]
316 self.check_invariants();
317
318 Ok(worktree)
319 }
320
321 pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
322 if let Some(project) = self.projects.get_mut(&project_id) {
323 if project.host_connection_id == connection_id {
324 project.share = Some(ProjectShare::default());
325 return true;
326 }
327 }
328 false
329 }
330
331 pub fn share_worktree(
332 &mut self,
333 project_id: u64,
334 worktree_id: u64,
335 connection_id: ConnectionId,
336 entries: HashMap<u64, proto::Entry>,
337 ) -> Option<Vec<UserId>> {
338 if let Some(project) = self.projects.get_mut(&project_id) {
339 if project.host_connection_id == connection_id {
340 if let Some(share) = project.share.as_mut() {
341 share
342 .worktrees
343 .insert(worktree_id, WorktreeShare { entries });
344 return Some(project.authorized_user_ids());
345 }
346 }
347 }
348 None
349 }
350
351 pub fn unshare_worktree(
352 &mut self,
353 worktree_id: u64,
354 acting_connection_id: ConnectionId,
355 ) -> tide::Result<UnsharedWorktree> {
356 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
357 worktree
358 } else {
359 return Err(anyhow!("no such worktree"))?;
360 };
361
362 if worktree.host_connection_id != acting_connection_id {
363 return Err(anyhow!("not your worktree"))?;
364 }
365
366 let connection_ids = worktree.connection_ids();
367 let authorized_user_ids = worktree.authorized_user_ids.clone();
368 if let Some(share) = worktree.share.take() {
369 for connection_id in share.guests.into_keys() {
370 if let Some(connection) = self.connections.get_mut(&connection_id) {
371 connection.worktrees.remove(&worktree_id);
372 }
373 }
374
375 #[cfg(test)]
376 self.check_invariants();
377
378 Ok(UnsharedWorktree {
379 connection_ids,
380 authorized_user_ids,
381 })
382 } else {
383 Err(anyhow!("worktree is not shared"))?
384 }
385 }
386
387 pub fn join_worktree(
388 &mut self,
389 connection_id: ConnectionId,
390 user_id: UserId,
391 worktree_id: u64,
392 ) -> tide::Result<JoinedWorktree> {
393 let connection = self
394 .connections
395 .get_mut(&connection_id)
396 .ok_or_else(|| anyhow!("no such connection"))?;
397 let worktree = self
398 .worktrees
399 .get_mut(&worktree_id)
400 .and_then(|worktree| {
401 if worktree.authorized_user_ids.contains(&user_id) {
402 Some(worktree)
403 } else {
404 None
405 }
406 })
407 .ok_or_else(|| anyhow!("no such worktree"))?;
408
409 let share = worktree.share_mut()?;
410 connection.worktrees.insert(worktree_id);
411
412 let mut replica_id = 1;
413 while share.active_replica_ids.contains(&replica_id) {
414 replica_id += 1;
415 }
416 share.active_replica_ids.insert(replica_id);
417 share.guests.insert(connection_id, (replica_id, user_id));
418
419 #[cfg(test)]
420 self.check_invariants();
421
422 Ok(JoinedWorktree {
423 replica_id,
424 worktree: &self.worktrees[&worktree_id],
425 })
426 }
427
428 pub fn leave_worktree(
429 &mut self,
430 connection_id: ConnectionId,
431 worktree_id: u64,
432 ) -> Option<LeftWorktree> {
433 let worktree = self.worktrees.get_mut(&worktree_id)?;
434 let share = worktree.share.as_mut()?;
435 let (replica_id, _) = share.guests.remove(&connection_id)?;
436 share.active_replica_ids.remove(&replica_id);
437
438 if let Some(connection) = self.connections.get_mut(&connection_id) {
439 connection.worktrees.remove(&worktree_id);
440 }
441
442 let connection_ids = worktree.connection_ids();
443 let authorized_user_ids = worktree.authorized_user_ids.clone();
444
445 #[cfg(test)]
446 self.check_invariants();
447
448 Some(LeftWorktree {
449 connection_ids,
450 authorized_user_ids,
451 })
452 }
453
454 pub fn update_worktree(
455 &mut self,
456 connection_id: ConnectionId,
457 worktree_id: u64,
458 removed_entries: &[u64],
459 updated_entries: &[proto::Entry],
460 ) -> tide::Result<Vec<ConnectionId>> {
461 let worktree = self.write_worktree(worktree_id, connection_id)?;
462 let share = worktree.share_mut()?;
463 for entry_id in removed_entries {
464 share.entries.remove(&entry_id);
465 }
466 for entry in updated_entries {
467 share.entries.insert(entry.id, entry.clone());
468 }
469 Ok(worktree.connection_ids())
470 }
471
472 pub fn worktree_host_connection_id(
473 &self,
474 connection_id: ConnectionId,
475 worktree_id: u64,
476 ) -> tide::Result<ConnectionId> {
477 Ok(self
478 .read_worktree(worktree_id, connection_id)?
479 .host_connection_id)
480 }
481
482 pub fn worktree_guest_connection_ids(
483 &self,
484 connection_id: ConnectionId,
485 worktree_id: u64,
486 ) -> tide::Result<Vec<ConnectionId>> {
487 Ok(self
488 .read_worktree(worktree_id, connection_id)?
489 .share()?
490 .guests
491 .keys()
492 .copied()
493 .collect())
494 }
495
496 pub fn worktree_connection_ids(
497 &self,
498 connection_id: ConnectionId,
499 worktree_id: u64,
500 ) -> tide::Result<Vec<ConnectionId>> {
501 Ok(self
502 .read_worktree(worktree_id, connection_id)?
503 .connection_ids())
504 }
505
506 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
507 Some(self.channels.get(&channel_id)?.connection_ids())
508 }
509
510 fn read_worktree(
511 &self,
512 worktree_id: u64,
513 connection_id: ConnectionId,
514 ) -> tide::Result<&Worktree> {
515 let worktree = self
516 .worktrees
517 .get(&worktree_id)
518 .ok_or_else(|| anyhow!("worktree not found"))?;
519
520 if worktree.host_connection_id == connection_id
521 || worktree.share()?.guests.contains_key(&connection_id)
522 {
523 Ok(worktree)
524 } else {
525 Err(anyhow!(
526 "{} is not a member of worktree {}",
527 connection_id,
528 worktree_id
529 ))?
530 }
531 }
532
533 fn write_worktree(
534 &mut self,
535 worktree_id: u64,
536 connection_id: ConnectionId,
537 ) -> tide::Result<&mut Worktree> {
538 let worktree = self
539 .worktrees
540 .get_mut(&worktree_id)
541 .ok_or_else(|| anyhow!("worktree not found"))?;
542
543 if worktree.host_connection_id == connection_id
544 || worktree
545 .share
546 .as_ref()
547 .map_or(false, |share| share.guests.contains_key(&connection_id))
548 {
549 Ok(worktree)
550 } else {
551 Err(anyhow!(
552 "{} is not a member of worktree {}",
553 connection_id,
554 worktree_id
555 ))?
556 }
557 }
558
559 #[cfg(test)]
560 fn check_invariants(&self) {
561 for (connection_id, connection) in &self.connections {
562 for worktree_id in &connection.worktrees {
563 let worktree = &self.worktrees.get(&worktree_id).unwrap();
564 if worktree.host_connection_id != *connection_id {
565 assert!(worktree.share().unwrap().guests.contains_key(connection_id));
566 }
567 }
568 for channel_id in &connection.channels {
569 let channel = self.channels.get(channel_id).unwrap();
570 assert!(channel.connection_ids.contains(connection_id));
571 }
572 assert!(self
573 .connections_by_user_id
574 .get(&connection.user_id)
575 .unwrap()
576 .contains(connection_id));
577 }
578
579 for (user_id, connection_ids) in &self.connections_by_user_id {
580 for connection_id in connection_ids {
581 assert_eq!(
582 self.connections.get(connection_id).unwrap().user_id,
583 *user_id
584 );
585 }
586 }
587
588 for (worktree_id, worktree) in &self.worktrees {
589 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
590 assert!(host_connection.worktrees.contains(worktree_id));
591
592 for authorized_user_ids in &worktree.authorized_user_ids {
593 let visible_worktree_ids = self
594 .visible_worktrees_by_user_id
595 .get(authorized_user_ids)
596 .unwrap();
597 assert!(visible_worktree_ids.contains(worktree_id));
598 }
599
600 if let Some(share) = &worktree.share {
601 for guest_connection_id in share.guests.keys() {
602 let guest_connection = self.connections.get(guest_connection_id).unwrap();
603 assert!(guest_connection.worktrees.contains(worktree_id));
604 }
605 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
606 assert_eq!(
607 share.active_replica_ids,
608 share
609 .guests
610 .values()
611 .map(|(replica_id, _)| *replica_id)
612 .collect::<HashSet<_>>(),
613 );
614 }
615 }
616
617 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
618 for worktree_id in visible_worktree_ids {
619 let worktree = self.worktrees.get(worktree_id).unwrap();
620 assert!(worktree.authorized_user_ids.contains(user_id));
621 }
622 }
623
624 for (channel_id, channel) in &self.channels {
625 for connection_id in &channel.connection_ids {
626 let connection = self.connections.get(connection_id).unwrap();
627 assert!(connection.channels.contains(channel_id));
628 }
629 }
630 }
631}
632
633impl Worktree {
634 pub fn connection_ids(&self) -> Vec<ConnectionId> {
635 if let Some(share) = &self.share {
636 share
637 .guests
638 .keys()
639 .copied()
640 .chain(Some(self.host_connection_id))
641 .collect()
642 } else {
643 vec![self.host_connection_id]
644 }
645 }
646
647 pub fn share(&self) -> tide::Result<&ProjectShare> {
648 Ok(self
649 .share
650 .as_ref()
651 .ok_or_else(|| anyhow!("worktree is not shared"))?)
652 }
653
654 fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
655 Ok(self
656 .share
657 .as_mut()
658 .ok_or_else(|| anyhow!("worktree is not shared"))?)
659 }
660}
661
662impl Channel {
663 fn connection_ids(&self) -> Vec<ConnectionId> {
664 self.connection_ids.iter().copied().collect()
665 }
666}