1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use std::collections::{hash_map, HashMap, HashSet};
4use zrpc::{proto, ConnectionId};
5
6#[derive(Default)]
7pub struct Store {
8 connections: HashMap<ConnectionId, ConnectionState>,
9 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
10 worktrees: HashMap<u64, Worktree>,
11 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
12 channels: HashMap<ChannelId, Channel>,
13 next_worktree_id: u64,
14}
15
16struct ConnectionState {
17 user_id: UserId,
18 worktrees: HashSet<u64>,
19 channels: HashSet<ChannelId>,
20}
21
22pub struct Worktree {
23 pub host_connection_id: ConnectionId,
24 pub collaborator_user_ids: Vec<UserId>,
25 pub root_name: String,
26 pub share: Option<WorktreeShare>,
27}
28
29pub struct WorktreeShare {
30 pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
31 pub active_replica_ids: HashSet<ReplicaId>,
32 pub entries: HashMap<u64, proto::Entry>,
33}
34
35#[derive(Default)]
36pub struct Channel {
37 pub connection_ids: HashSet<ConnectionId>,
38}
39
40pub type ReplicaId = u16;
41
42#[derive(Default)]
43pub struct RemovedConnectionState {
44 pub hosted_worktrees: HashMap<u64, Worktree>,
45 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
46 pub collaborator_ids: HashSet<UserId>,
47}
48
49pub struct JoinedWorktree<'a> {
50 pub replica_id: ReplicaId,
51 pub worktree: &'a Worktree,
52}
53
54pub struct UnsharedWorktree {
55 pub connection_ids: Vec<ConnectionId>,
56 pub collaborator_ids: Vec<UserId>,
57}
58
59pub struct LeftWorktree {
60 pub connection_ids: Vec<ConnectionId>,
61 pub collaborator_ids: Vec<UserId>,
62}
63
64impl Store {
65 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
66 self.connections.insert(
67 connection_id,
68 ConnectionState {
69 user_id,
70 worktrees: Default::default(),
71 channels: Default::default(),
72 },
73 );
74 self.connections_by_user_id
75 .entry(user_id)
76 .or_default()
77 .insert(connection_id);
78 }
79
80 pub fn remove_connection(
81 &mut self,
82 connection_id: ConnectionId,
83 ) -> tide::Result<RemovedConnectionState> {
84 let connection = if let Some(connection) = self.connections.get(&connection_id) {
85 connection
86 } else {
87 return Err(anyhow!("no such connection"))?;
88 };
89
90 for channel_id in &connection.channels {
91 if let Some(channel) = self.channels.get_mut(&channel_id) {
92 channel.connection_ids.remove(&connection_id);
93 }
94 }
95
96 let user_connections = self
97 .connections_by_user_id
98 .get_mut(&connection.user_id)
99 .unwrap();
100 user_connections.remove(&connection_id);
101 if user_connections.is_empty() {
102 self.connections_by_user_id.remove(&connection.user_id);
103 }
104
105 let mut result = RemovedConnectionState::default();
106 for worktree_id in connection.worktrees.clone() {
107 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108 result
109 .collaborator_ids
110 .extend(worktree.collaborator_user_ids.iter().copied());
111 result.hosted_worktrees.insert(worktree_id, worktree);
112 } else {
113 if let Some(worktree) = self.worktrees.get(&worktree_id) {
114 result
115 .guest_worktree_ids
116 .insert(worktree_id, worktree.connection_ids());
117 result
118 .collaborator_ids
119 .extend(worktree.collaborator_user_ids.iter().copied());
120 }
121 }
122 }
123
124 Ok(result)
125 }
126
127 #[cfg(test)]
128 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129 self.channels.get(&id)
130 }
131
132 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133 if let Some(connection) = self.connections.get_mut(&connection_id) {
134 connection.channels.insert(channel_id);
135 self.channels
136 .entry(channel_id)
137 .or_default()
138 .connection_ids
139 .insert(connection_id);
140 }
141 }
142
143 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.remove(&channel_id);
146 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147 entry.get_mut().connection_ids.remove(&connection_id);
148 if entry.get_mut().connection_ids.is_empty() {
149 entry.remove();
150 }
151 }
152 }
153 }
154
155 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156 Ok(self
157 .connections
158 .get(&connection_id)
159 .ok_or_else(|| anyhow!("unknown connection"))?
160 .user_id)
161 }
162
163 pub fn connection_ids_for_user<'a>(
164 &'a self,
165 user_id: UserId,
166 ) -> impl 'a + Iterator<Item = ConnectionId> {
167 self.connections_by_user_id
168 .get(&user_id)
169 .into_iter()
170 .flatten()
171 .copied()
172 }
173
174 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
175 let mut collaborators = HashMap::new();
176 for worktree_id in self
177 .visible_worktrees_by_user_id
178 .get(&user_id)
179 .unwrap_or(&HashSet::new())
180 {
181 let worktree = &self.worktrees[worktree_id];
182
183 let mut guests = HashSet::new();
184 if let Ok(share) = worktree.share() {
185 for guest_connection_id in share.guest_connection_ids.keys() {
186 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187 guests.insert(user_id.to_proto());
188 }
189 }
190 }
191
192 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193 collaborators
194 .entry(host_user_id)
195 .or_insert_with(|| proto::Collaborator {
196 user_id: host_user_id.to_proto(),
197 worktrees: Vec::new(),
198 })
199 .worktrees
200 .push(proto::WorktreeMetadata {
201 id: *worktree_id,
202 root_name: worktree.root_name.clone(),
203 is_shared: worktree.share.is_some(),
204 guests: guests.into_iter().collect(),
205 });
206 }
207 }
208
209 collaborators.into_values().collect()
210 }
211
212 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213 let worktree_id = self.next_worktree_id;
214 for collaborator_user_id in &worktree.collaborator_user_ids {
215 self.visible_worktrees_by_user_id
216 .entry(*collaborator_user_id)
217 .or_default()
218 .insert(worktree_id);
219 }
220 self.next_worktree_id += 1;
221 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222 connection.worktrees.insert(worktree_id);
223 }
224 self.worktrees.insert(worktree_id, worktree);
225
226 #[cfg(test)]
227 self.check_invariants();
228
229 worktree_id
230 }
231
232 pub fn remove_worktree(
233 &mut self,
234 worktree_id: u64,
235 acting_connection_id: ConnectionId,
236 ) -> tide::Result<Worktree> {
237 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238 if e.get().host_connection_id != acting_connection_id {
239 Err(anyhow!("not your worktree"))?;
240 }
241 e.remove()
242 } else {
243 return Err(anyhow!("no such worktree"))?;
244 };
245
246 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247 connection.worktrees.remove(&worktree_id);
248 }
249
250 if let Some(share) = &worktree.share {
251 for connection_id in share.guest_connection_ids.keys() {
252 if let Some(connection) = self.connections.get_mut(connection_id) {
253 connection.worktrees.remove(&worktree_id);
254 }
255 }
256 }
257
258 for collaborator_user_id in &worktree.collaborator_user_ids {
259 if let Some(visible_worktrees) = self
260 .visible_worktrees_by_user_id
261 .get_mut(&collaborator_user_id)
262 {
263 visible_worktrees.remove(&worktree_id);
264 }
265 }
266
267 #[cfg(test)]
268 self.check_invariants();
269
270 Ok(worktree)
271 }
272
273 pub fn share_worktree(
274 &mut self,
275 worktree_id: u64,
276 connection_id: ConnectionId,
277 entries: HashMap<u64, proto::Entry>,
278 ) -> Option<Vec<UserId>> {
279 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280 if worktree.host_connection_id == connection_id {
281 worktree.share = Some(WorktreeShare {
282 guest_connection_ids: Default::default(),
283 active_replica_ids: Default::default(),
284 entries,
285 });
286 return Some(worktree.collaborator_user_ids.clone());
287 }
288 }
289 None
290 }
291
292 pub fn unshare_worktree(
293 &mut self,
294 worktree_id: u64,
295 acting_connection_id: ConnectionId,
296 ) -> tide::Result<UnsharedWorktree> {
297 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298 worktree
299 } else {
300 return Err(anyhow!("no such worktree"))?;
301 };
302
303 if worktree.host_connection_id != acting_connection_id {
304 return Err(anyhow!("not your worktree"))?;
305 }
306
307 let connection_ids = worktree.connection_ids();
308
309 if let Some(_) = worktree.share.take() {
310 for connection_id in &connection_ids {
311 if let Some(connection) = self.connections.get_mut(connection_id) {
312 connection.worktrees.remove(&worktree_id);
313 }
314 }
315 Ok(UnsharedWorktree {
316 connection_ids,
317 collaborator_ids: worktree.collaborator_user_ids.clone(),
318 })
319 } else {
320 Err(anyhow!("worktree is not shared"))?
321 }
322 }
323
324 pub fn join_worktree(
325 &mut self,
326 connection_id: ConnectionId,
327 user_id: UserId,
328 worktree_id: u64,
329 ) -> tide::Result<JoinedWorktree> {
330 let connection = self
331 .connections
332 .get_mut(&connection_id)
333 .ok_or_else(|| anyhow!("no such connection"))?;
334 let worktree = self
335 .worktrees
336 .get_mut(&worktree_id)
337 .and_then(|worktree| {
338 if worktree.collaborator_user_ids.contains(&user_id) {
339 Some(worktree)
340 } else {
341 None
342 }
343 })
344 .ok_or_else(|| anyhow!("no such worktree"))?;
345
346 let share = worktree.share_mut()?;
347 connection.worktrees.insert(worktree_id);
348
349 let mut replica_id = 1;
350 while share.active_replica_ids.contains(&replica_id) {
351 replica_id += 1;
352 }
353 share.active_replica_ids.insert(replica_id);
354 share.guest_connection_ids.insert(connection_id, replica_id);
355 Ok(JoinedWorktree {
356 replica_id,
357 worktree,
358 })
359 }
360
361 pub fn leave_worktree(
362 &mut self,
363 connection_id: ConnectionId,
364 worktree_id: u64,
365 ) -> Option<LeftWorktree> {
366 let worktree = self.worktrees.get_mut(&worktree_id)?;
367 let share = worktree.share.as_mut()?;
368 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
369 share.active_replica_ids.remove(&replica_id);
370 Some(LeftWorktree {
371 connection_ids: worktree.connection_ids(),
372 collaborator_ids: worktree.collaborator_user_ids.clone(),
373 })
374 }
375
376 pub fn update_worktree(
377 &mut self,
378 connection_id: ConnectionId,
379 worktree_id: u64,
380 removed_entries: &[u64],
381 updated_entries: &[proto::Entry],
382 ) -> tide::Result<Vec<ConnectionId>> {
383 let worktree = self.write_worktree(worktree_id, connection_id)?;
384 let share = worktree.share_mut()?;
385 for entry_id in removed_entries {
386 share.entries.remove(&entry_id);
387 }
388 for entry in updated_entries {
389 share.entries.insert(entry.id, entry.clone());
390 }
391 Ok(worktree.connection_ids())
392 }
393
394 pub fn worktree_host_connection_id(
395 &self,
396 connection_id: ConnectionId,
397 worktree_id: u64,
398 ) -> tide::Result<ConnectionId> {
399 Ok(self
400 .read_worktree(worktree_id, connection_id)?
401 .host_connection_id)
402 }
403
404 pub fn worktree_guest_connection_ids(
405 &self,
406 connection_id: ConnectionId,
407 worktree_id: u64,
408 ) -> tide::Result<Vec<ConnectionId>> {
409 Ok(self
410 .read_worktree(worktree_id, connection_id)?
411 .share()?
412 .guest_connection_ids
413 .keys()
414 .copied()
415 .collect())
416 }
417
418 pub fn worktree_connection_ids(
419 &self,
420 connection_id: ConnectionId,
421 worktree_id: u64,
422 ) -> tide::Result<Vec<ConnectionId>> {
423 Ok(self
424 .read_worktree(worktree_id, connection_id)?
425 .connection_ids())
426 }
427
428 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
429 Some(self.channels.get(&channel_id)?.connection_ids())
430 }
431
432 fn read_worktree(
433 &self,
434 worktree_id: u64,
435 connection_id: ConnectionId,
436 ) -> tide::Result<&Worktree> {
437 let worktree = self
438 .worktrees
439 .get(&worktree_id)
440 .ok_or_else(|| anyhow!("worktree not found"))?;
441
442 if worktree.host_connection_id == connection_id
443 || worktree
444 .share()?
445 .guest_connection_ids
446 .contains_key(&connection_id)
447 {
448 Ok(worktree)
449 } else {
450 Err(anyhow!(
451 "{} is not a member of worktree {}",
452 connection_id,
453 worktree_id
454 ))?
455 }
456 }
457
458 fn write_worktree(
459 &mut self,
460 worktree_id: u64,
461 connection_id: ConnectionId,
462 ) -> tide::Result<&mut Worktree> {
463 let worktree = self
464 .worktrees
465 .get_mut(&worktree_id)
466 .ok_or_else(|| anyhow!("worktree not found"))?;
467
468 if worktree.host_connection_id == connection_id
469 || worktree.share.as_ref().map_or(false, |share| {
470 share.guest_connection_ids.contains_key(&connection_id)
471 })
472 {
473 Ok(worktree)
474 } else {
475 Err(anyhow!(
476 "{} is not a member of worktree {}",
477 connection_id,
478 worktree_id
479 ))?
480 }
481 }
482
483 #[cfg(test)]
484 fn check_invariants(&self) {
485 for (connection_id, connection) in &self.connections {
486 for worktree_id in &connection.worktrees {
487 let worktree = &self.worktrees.get(&worktree_id).unwrap();
488 if worktree.host_connection_id != *connection_id {
489 assert!(worktree
490 .share()
491 .unwrap()
492 .guest_connection_ids
493 .contains_key(connection_id));
494 }
495 }
496 for channel_id in &connection.channels {
497 let channel = self.channels.get(channel_id).unwrap();
498 assert!(channel.connection_ids.contains(connection_id));
499 }
500 assert!(self
501 .connections_by_user_id
502 .get(&connection.user_id)
503 .unwrap()
504 .contains(connection_id));
505 }
506
507 for (user_id, connection_ids) in &self.connections_by_user_id {
508 for connection_id in connection_ids {
509 assert_eq!(
510 self.connections.get(connection_id).unwrap().user_id,
511 *user_id
512 );
513 }
514 }
515
516 for (worktree_id, worktree) in &self.worktrees {
517 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
518 assert!(host_connection.worktrees.contains(worktree_id));
519
520 for collaborator_id in &worktree.collaborator_user_ids {
521 let visible_worktree_ids = self
522 .visible_worktrees_by_user_id
523 .get(collaborator_id)
524 .unwrap();
525 assert!(visible_worktree_ids.contains(worktree_id));
526 }
527
528 if let Some(share) = &worktree.share {
529 for guest_connection_id in share.guest_connection_ids.keys() {
530 let guest_connection = self.connections.get(guest_connection_id).unwrap();
531 assert!(guest_connection.worktrees.contains(worktree_id));
532 }
533 assert_eq!(
534 share.active_replica_ids.len(),
535 share.guest_connection_ids.len(),
536 );
537 assert_eq!(
538 share.active_replica_ids,
539 share
540 .guest_connection_ids
541 .values()
542 .copied()
543 .collect::<HashSet<_>>(),
544 );
545 }
546 }
547
548 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
549 for worktree_id in visible_worktree_ids {
550 let worktree = self.worktrees.get(worktree_id).unwrap();
551 assert!(worktree.collaborator_user_ids.contains(user_id));
552 }
553 }
554
555 for (channel_id, channel) in &self.channels {
556 for connection_id in &channel.connection_ids {
557 let connection = self.connections.get(connection_id).unwrap();
558 assert!(connection.channels.contains(channel_id));
559 }
560 }
561 }
562}
563
564impl Worktree {
565 pub fn connection_ids(&self) -> Vec<ConnectionId> {
566 if let Some(share) = &self.share {
567 share
568 .guest_connection_ids
569 .keys()
570 .copied()
571 .chain(Some(self.host_connection_id))
572 .collect()
573 } else {
574 vec![self.host_connection_id]
575 }
576 }
577
578 pub fn share(&self) -> tide::Result<&WorktreeShare> {
579 Ok(self
580 .share
581 .as_ref()
582 .ok_or_else(|| anyhow!("worktree is not shared"))?)
583 }
584
585 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
586 Ok(self
587 .share
588 .as_mut()
589 .ok_or_else(|| anyhow!("worktree is not shared"))?)
590 }
591}
592
593impl Channel {
594 fn connection_ids(&self) -> Vec<ConnectionId> {
595 self.connection_ids.iter().copied().collect()
596 }
597}