1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use std::collections::{hash_map, HashMap, HashSet};
4use zrpc::{proto, ConnectionId};
5
/// In-memory bookkeeping for live connections, worktrees and channels.
///
/// Several maps below index the same relationships from different directions;
/// the test-only `check_invariants` method asserts that they agree.
#[derive(Default)]
pub struct Store {
    /// Per-connection state, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index from a user to every connection registered for that user.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// All registered worktrees, keyed by the id assigned in `add_worktree`.
    worktrees: HashMap<u64, Worktree>,
    /// For each user, the ids of worktrees that list the user in `collaborator_user_ids`.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channel membership, keyed by channel id.
    channels: HashMap<ChannelId, Channel>,
    /// Id handed out by the next `add_worktree` call; incremented on every call.
    next_worktree_id: u64,
}
15
/// State tracked for a single registered connection.
struct ConnectionState {
    /// The user this connection was registered for in `add_connection`.
    user_id: UserId,
    /// Ids of worktrees this connection hosts or has joined as a guest.
    worktrees: HashSet<u64>,
    /// Ids of channels this connection has joined.
    channels: HashSet<ChannelId>,
}
21
/// A worktree registered with the store by its host connection.
pub struct Worktree {
    /// The connection that registered (hosts) this worktree.
    pub host_connection_id: ConnectionId,
    /// Users that can see this worktree and are allowed to join it.
    pub collaborator_user_ids: Vec<UserId>,
    /// Root name reported to collaborators in the worktree metadata.
    pub root_name: String,
    /// Present while the worktree is shared; `None` otherwise.
    pub share: Option<WorktreeShare>,
}
28
/// State that only exists while a worktree is shared.
pub struct WorktreeShare {
    /// Guest connections, each mapped to the replica id it was assigned on joining.
    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests.
    pub active_replica_ids: HashSet<ReplicaId>,
    /// The worktree's entries, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
}
34
/// Membership of a single channel.
#[derive(Default)]
pub struct Channel {
    /// Connections that have joined this channel.
    pub connection_ids: HashSet<ConnectionId>,
}
39
/// Identifier of a worktree replica; guests are assigned ids starting from 1
/// by `Store::join_worktree`.
pub type ReplicaId = u16;
41
/// Summary of what `Store::remove_connection` tore down.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Worktrees the connection hosted (now removed from the store), keyed by id.
    pub hosted_worktrees: HashMap<u64, Worktree>,
    /// For each worktree the connection was a guest of, the connection ids the
    /// store reported for that worktree.
    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Union of the collaborator user ids of every affected worktree.
    pub collaborator_ids: HashSet<UserId>,
}
48
/// Result of a successful `Store::join_worktree`.
pub struct JoinedWorktree<'a> {
    /// Replica id assigned to the joining connection.
    pub replica_id: ReplicaId,
    /// The joined worktree, borrowed from the store.
    pub worktree: &'a Worktree,
}
53
/// Result of a successful `Store::unshare_worktree`.
pub struct UnsharedWorktree {
    /// Host and guest connections as they were just before the share was dropped.
    pub connection_ids: Vec<ConnectionId>,
    /// The worktree's collaborator user ids.
    pub collaborator_ids: Vec<UserId>,
}
58
/// Result of a successful `Store::leave_worktree`.
pub struct LeftWorktree {
    /// Connections still attached to the worktree after the guest left.
    pub connection_ids: Vec<ConnectionId>,
    /// The worktree's collaborator user ids.
    pub collaborator_ids: Vec<UserId>,
}
63
64impl Store {
65 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
66 self.connections.insert(
67 connection_id,
68 ConnectionState {
69 user_id,
70 worktrees: Default::default(),
71 channels: Default::default(),
72 },
73 );
74 self.connections_by_user_id
75 .entry(user_id)
76 .or_default()
77 .insert(connection_id);
78 }
79
80 pub fn remove_connection(
81 &mut self,
82 connection_id: ConnectionId,
83 ) -> tide::Result<RemovedConnectionState> {
84 let connection = if let Some(connection) = self.connections.get(&connection_id) {
85 connection
86 } else {
87 return Err(anyhow!("no such connection"))?;
88 };
89
90 for channel_id in &connection.channels {
91 if let Some(channel) = self.channels.get_mut(&channel_id) {
92 channel.connection_ids.remove(&connection_id);
93 }
94 }
95
96 let user_connections = self
97 .connections_by_user_id
98 .get_mut(&connection.user_id)
99 .unwrap();
100 user_connections.remove(&connection_id);
101 if user_connections.is_empty() {
102 self.connections_by_user_id.remove(&connection.user_id);
103 }
104
105 let mut result = RemovedConnectionState::default();
106 for worktree_id in connection.worktrees.clone() {
107 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108 result
109 .collaborator_ids
110 .extend(worktree.collaborator_user_ids.iter().copied());
111 result.hosted_worktrees.insert(worktree_id, worktree);
112 } else {
113 if let Some(worktree) = self.worktrees.get(&worktree_id) {
114 result
115 .guest_worktree_ids
116 .insert(worktree_id, worktree.connection_ids());
117 result
118 .collaborator_ids
119 .extend(worktree.collaborator_user_ids.iter().copied());
120 }
121 }
122 }
123
124 Ok(result)
125 }
126
127 #[cfg(test)]
128 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129 self.channels.get(&id)
130 }
131
132 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133 if let Some(connection) = self.connections.get_mut(&connection_id) {
134 connection.channels.insert(channel_id);
135 self.channels
136 .entry(channel_id)
137 .or_default()
138 .connection_ids
139 .insert(connection_id);
140 }
141 }
142
143 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.remove(&channel_id);
146 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147 entry.get_mut().connection_ids.remove(&connection_id);
148 if entry.get_mut().connection_ids.is_empty() {
149 entry.remove();
150 }
151 }
152 }
153 }
154
155 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156 Ok(self
157 .connections
158 .get(&connection_id)
159 .ok_or_else(|| anyhow!("unknown connection"))?
160 .user_id)
161 }
162
163 pub fn connection_ids_for_user<'a>(
164 &'a self,
165 user_id: UserId,
166 ) -> impl 'a + Iterator<Item = ConnectionId> {
167 self.connections_by_user_id
168 .get(&user_id)
169 .into_iter()
170 .flatten()
171 .copied()
172 }
173
174 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
175 let mut collaborators = HashMap::new();
176 for worktree_id in self
177 .visible_worktrees_by_user_id
178 .get(&user_id)
179 .unwrap_or(&HashSet::new())
180 {
181 let worktree = &self.worktrees[worktree_id];
182
183 let mut guests = HashSet::new();
184 if let Ok(share) = worktree.share() {
185 for guest_connection_id in share.guest_connection_ids.keys() {
186 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187 guests.insert(user_id.to_proto());
188 }
189 }
190 }
191
192 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193 collaborators
194 .entry(host_user_id)
195 .or_insert_with(|| proto::Collaborator {
196 user_id: host_user_id.to_proto(),
197 worktrees: Vec::new(),
198 })
199 .worktrees
200 .push(proto::WorktreeMetadata {
201 id: *worktree_id,
202 root_name: worktree.root_name.clone(),
203 is_shared: worktree.share.is_some(),
204 guests: guests.into_iter().collect(),
205 });
206 }
207 }
208
209 collaborators.into_values().collect()
210 }
211
212 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213 let worktree_id = self.next_worktree_id;
214 for collaborator_user_id in &worktree.collaborator_user_ids {
215 self.visible_worktrees_by_user_id
216 .entry(*collaborator_user_id)
217 .or_default()
218 .insert(worktree_id);
219 }
220 self.next_worktree_id += 1;
221 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222 connection.worktrees.insert(worktree_id);
223 }
224 self.worktrees.insert(worktree_id, worktree);
225
226 #[cfg(test)]
227 self.check_invariants();
228
229 worktree_id
230 }
231
232 pub fn remove_worktree(
233 &mut self,
234 worktree_id: u64,
235 acting_connection_id: ConnectionId,
236 ) -> tide::Result<Worktree> {
237 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238 if e.get().host_connection_id != acting_connection_id {
239 Err(anyhow!("not your worktree"))?;
240 }
241 e.remove()
242 } else {
243 return Err(anyhow!("no such worktree"))?;
244 };
245
246 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247 connection.worktrees.remove(&worktree_id);
248 }
249
250 if let Some(share) = &worktree.share {
251 for connection_id in share.guest_connection_ids.keys() {
252 if let Some(connection) = self.connections.get_mut(connection_id) {
253 connection.worktrees.remove(&worktree_id);
254 }
255 }
256 }
257
258 for collaborator_user_id in &worktree.collaborator_user_ids {
259 if let Some(visible_worktrees) = self
260 .visible_worktrees_by_user_id
261 .get_mut(&collaborator_user_id)
262 {
263 visible_worktrees.remove(&worktree_id);
264 }
265 }
266
267 #[cfg(test)]
268 self.check_invariants();
269
270 Ok(worktree)
271 }
272
273 pub fn share_worktree(
274 &mut self,
275 worktree_id: u64,
276 connection_id: ConnectionId,
277 entries: HashMap<u64, proto::Entry>,
278 ) -> Option<Vec<UserId>> {
279 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280 if worktree.host_connection_id == connection_id {
281 worktree.share = Some(WorktreeShare {
282 guest_connection_ids: Default::default(),
283 active_replica_ids: Default::default(),
284 entries,
285 });
286 return Some(worktree.collaborator_user_ids.clone());
287 }
288 }
289 None
290 }
291
292 pub fn unshare_worktree(
293 &mut self,
294 worktree_id: u64,
295 acting_connection_id: ConnectionId,
296 ) -> tide::Result<UnsharedWorktree> {
297 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298 worktree
299 } else {
300 return Err(anyhow!("no such worktree"))?;
301 };
302
303 if worktree.host_connection_id != acting_connection_id {
304 return Err(anyhow!("not your worktree"))?;
305 }
306
307 let connection_ids = worktree.connection_ids();
308 let collaborator_ids = worktree.collaborator_user_ids.clone();
309 if let Some(share) = worktree.share.take() {
310 for connection_id in share.guest_connection_ids.into_keys() {
311 if let Some(connection) = self.connections.get_mut(&connection_id) {
312 connection.worktrees.remove(&worktree_id);
313 }
314 }
315
316 #[cfg(test)]
317 self.check_invariants();
318
319 Ok(UnsharedWorktree {
320 connection_ids,
321 collaborator_ids,
322 })
323 } else {
324 Err(anyhow!("worktree is not shared"))?
325 }
326 }
327
328 pub fn join_worktree(
329 &mut self,
330 connection_id: ConnectionId,
331 user_id: UserId,
332 worktree_id: u64,
333 ) -> tide::Result<JoinedWorktree> {
334 let connection = self
335 .connections
336 .get_mut(&connection_id)
337 .ok_or_else(|| anyhow!("no such connection"))?;
338 let worktree = self
339 .worktrees
340 .get_mut(&worktree_id)
341 .and_then(|worktree| {
342 if worktree.collaborator_user_ids.contains(&user_id) {
343 Some(worktree)
344 } else {
345 None
346 }
347 })
348 .ok_or_else(|| anyhow!("no such worktree"))?;
349
350 let share = worktree.share_mut()?;
351 connection.worktrees.insert(worktree_id);
352
353 let mut replica_id = 1;
354 while share.active_replica_ids.contains(&replica_id) {
355 replica_id += 1;
356 }
357 share.active_replica_ids.insert(replica_id);
358 share.guest_connection_ids.insert(connection_id, replica_id);
359
360 #[cfg(test)]
361 self.check_invariants();
362
363 Ok(JoinedWorktree {
364 replica_id,
365 worktree: &self.worktrees[&worktree_id],
366 })
367 }
368
369 pub fn leave_worktree(
370 &mut self,
371 connection_id: ConnectionId,
372 worktree_id: u64,
373 ) -> Option<LeftWorktree> {
374 let worktree = self.worktrees.get_mut(&worktree_id)?;
375 let share = worktree.share.as_mut()?;
376 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
377 share.active_replica_ids.remove(&replica_id);
378
379 let connection = self.connections.get_mut(&connection_id)?;
380 connection.worktrees.remove(&worktree_id);
381
382 let connection_ids = worktree.connection_ids();
383 let collaborator_ids = worktree.collaborator_user_ids.clone();
384
385 #[cfg(test)]
386 self.check_invariants();
387
388 Some(LeftWorktree {
389 connection_ids,
390 collaborator_ids,
391 })
392 }
393
394 pub fn update_worktree(
395 &mut self,
396 connection_id: ConnectionId,
397 worktree_id: u64,
398 removed_entries: &[u64],
399 updated_entries: &[proto::Entry],
400 ) -> tide::Result<Vec<ConnectionId>> {
401 let worktree = self.write_worktree(worktree_id, connection_id)?;
402 let share = worktree.share_mut()?;
403 for entry_id in removed_entries {
404 share.entries.remove(&entry_id);
405 }
406 for entry in updated_entries {
407 share.entries.insert(entry.id, entry.clone());
408 }
409 Ok(worktree.connection_ids())
410 }
411
412 pub fn worktree_host_connection_id(
413 &self,
414 connection_id: ConnectionId,
415 worktree_id: u64,
416 ) -> tide::Result<ConnectionId> {
417 Ok(self
418 .read_worktree(worktree_id, connection_id)?
419 .host_connection_id)
420 }
421
422 pub fn worktree_guest_connection_ids(
423 &self,
424 connection_id: ConnectionId,
425 worktree_id: u64,
426 ) -> tide::Result<Vec<ConnectionId>> {
427 Ok(self
428 .read_worktree(worktree_id, connection_id)?
429 .share()?
430 .guest_connection_ids
431 .keys()
432 .copied()
433 .collect())
434 }
435
436 pub fn worktree_connection_ids(
437 &self,
438 connection_id: ConnectionId,
439 worktree_id: u64,
440 ) -> tide::Result<Vec<ConnectionId>> {
441 Ok(self
442 .read_worktree(worktree_id, connection_id)?
443 .connection_ids())
444 }
445
446 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447 Some(self.channels.get(&channel_id)?.connection_ids())
448 }
449
450 fn read_worktree(
451 &self,
452 worktree_id: u64,
453 connection_id: ConnectionId,
454 ) -> tide::Result<&Worktree> {
455 let worktree = self
456 .worktrees
457 .get(&worktree_id)
458 .ok_or_else(|| anyhow!("worktree not found"))?;
459
460 if worktree.host_connection_id == connection_id
461 || worktree
462 .share()?
463 .guest_connection_ids
464 .contains_key(&connection_id)
465 {
466 Ok(worktree)
467 } else {
468 Err(anyhow!(
469 "{} is not a member of worktree {}",
470 connection_id,
471 worktree_id
472 ))?
473 }
474 }
475
476 fn write_worktree(
477 &mut self,
478 worktree_id: u64,
479 connection_id: ConnectionId,
480 ) -> tide::Result<&mut Worktree> {
481 let worktree = self
482 .worktrees
483 .get_mut(&worktree_id)
484 .ok_or_else(|| anyhow!("worktree not found"))?;
485
486 if worktree.host_connection_id == connection_id
487 || worktree.share.as_ref().map_or(false, |share| {
488 share.guest_connection_ids.contains_key(&connection_id)
489 })
490 {
491 Ok(worktree)
492 } else {
493 Err(anyhow!(
494 "{} is not a member of worktree {}",
495 connection_id,
496 worktree_id
497 ))?
498 }
499 }
500
501 #[cfg(test)]
502 fn check_invariants(&self) {
503 for (connection_id, connection) in &self.connections {
504 for worktree_id in &connection.worktrees {
505 let worktree = &self.worktrees.get(&worktree_id).unwrap();
506 if worktree.host_connection_id != *connection_id {
507 assert!(worktree
508 .share()
509 .unwrap()
510 .guest_connection_ids
511 .contains_key(connection_id));
512 }
513 }
514 for channel_id in &connection.channels {
515 let channel = self.channels.get(channel_id).unwrap();
516 assert!(channel.connection_ids.contains(connection_id));
517 }
518 assert!(self
519 .connections_by_user_id
520 .get(&connection.user_id)
521 .unwrap()
522 .contains(connection_id));
523 }
524
525 for (user_id, connection_ids) in &self.connections_by_user_id {
526 for connection_id in connection_ids {
527 assert_eq!(
528 self.connections.get(connection_id).unwrap().user_id,
529 *user_id
530 );
531 }
532 }
533
534 for (worktree_id, worktree) in &self.worktrees {
535 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
536 assert!(host_connection.worktrees.contains(worktree_id));
537
538 for collaborator_id in &worktree.collaborator_user_ids {
539 let visible_worktree_ids = self
540 .visible_worktrees_by_user_id
541 .get(collaborator_id)
542 .unwrap();
543 assert!(visible_worktree_ids.contains(worktree_id));
544 }
545
546 if let Some(share) = &worktree.share {
547 for guest_connection_id in share.guest_connection_ids.keys() {
548 let guest_connection = self.connections.get(guest_connection_id).unwrap();
549 assert!(guest_connection.worktrees.contains(worktree_id));
550 }
551 assert_eq!(
552 share.active_replica_ids.len(),
553 share.guest_connection_ids.len(),
554 );
555 assert_eq!(
556 share.active_replica_ids,
557 share
558 .guest_connection_ids
559 .values()
560 .copied()
561 .collect::<HashSet<_>>(),
562 );
563 }
564 }
565
566 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
567 for worktree_id in visible_worktree_ids {
568 let worktree = self.worktrees.get(worktree_id).unwrap();
569 assert!(worktree.collaborator_user_ids.contains(user_id));
570 }
571 }
572
573 for (channel_id, channel) in &self.channels {
574 for connection_id in &channel.connection_ids {
575 let connection = self.connections.get(connection_id).unwrap();
576 assert!(connection.channels.contains(channel_id));
577 }
578 }
579 }
580}
581
582impl Worktree {
583 pub fn connection_ids(&self) -> Vec<ConnectionId> {
584 if let Some(share) = &self.share {
585 share
586 .guest_connection_ids
587 .keys()
588 .copied()
589 .chain(Some(self.host_connection_id))
590 .collect()
591 } else {
592 vec![self.host_connection_id]
593 }
594 }
595
596 pub fn share(&self) -> tide::Result<&WorktreeShare> {
597 Ok(self
598 .share
599 .as_ref()
600 .ok_or_else(|| anyhow!("worktree is not shared"))?)
601 }
602
603 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
604 Ok(self
605 .share
606 .as_mut()
607 .ok_or_else(|| anyhow!("worktree is not shared"))?)
608 }
609}
610
611impl Channel {
612 fn connection_ids(&self) -> Vec<ConnectionId> {
613 self.connection_ids.iter().copied().collect()
614 }
615}