1use crate::db::{ChannelId, UserId};
2use crate::errors::TideResultExt;
3use anyhow::anyhow;
4use std::collections::{hash_map, HashMap, HashSet};
5use zrpc::{proto, ConnectionId};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 worktrees: HashMap<u64, Worktree>,
12 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_worktree_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 worktrees: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Worktree {
24 pub host_connection_id: ConnectionId,
25 pub collaborator_user_ids: Vec<UserId>,
26 pub root_name: String,
27 pub share: Option<WorktreeShare>,
28}
29
30pub struct WorktreeShare {
31 pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
32 pub active_replica_ids: HashSet<ReplicaId>,
33 pub entries: HashMap<u64, proto::Entry>,
34}
35
36#[derive(Default)]
37pub struct Channel {
38 pub connection_ids: HashSet<ConnectionId>,
39}
40
/// Replica number assigned to a guest connection within a shared worktree.
pub type ReplicaId = u16;
42
43#[derive(Default)]
44pub struct RemovedConnectionState {
45 pub hosted_worktrees: HashMap<u64, Worktree>,
46 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
47 pub collaborator_ids: HashSet<UserId>,
48}
49
50pub struct JoinedWorktree<'a> {
51 pub replica_id: ReplicaId,
52 pub worktree: &'a Worktree,
53}
54
55pub struct UnsharedWorktree {
56 pub connection_ids: Vec<ConnectionId>,
57 pub collaborator_ids: Vec<UserId>,
58}
59
60pub struct LeftWorktree {
61 pub connection_ids: Vec<ConnectionId>,
62 pub collaborator_ids: Vec<UserId>,
63}
64
65impl Store {
66 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
67 self.connections.insert(
68 connection_id,
69 ConnectionState {
70 user_id,
71 worktrees: Default::default(),
72 channels: Default::default(),
73 },
74 );
75 self.connections_by_user_id
76 .entry(user_id)
77 .or_default()
78 .insert(connection_id);
79 }
80
81 pub fn remove_connection(
82 &mut self,
83 connection_id: ConnectionId,
84 ) -> tide::Result<RemovedConnectionState> {
85 let connection = if let Some(connection) = self.connections.get(&connection_id) {
86 connection
87 } else {
88 return Err(anyhow!("no such connection"))?;
89 };
90
91 for channel_id in &connection.channels {
92 if let Some(channel) = self.channels.get_mut(&channel_id) {
93 channel.connection_ids.remove(&connection_id);
94 }
95 }
96
97 let user_connections = self
98 .connections_by_user_id
99 .get_mut(&connection.user_id)
100 .unwrap();
101 user_connections.remove(&connection_id);
102 if user_connections.is_empty() {
103 self.connections_by_user_id.remove(&connection.user_id);
104 }
105
106 let mut result = RemovedConnectionState::default();
107 for worktree_id in connection.worktrees.clone() {
108 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109 result
110 .collaborator_ids
111 .extend(worktree.collaborator_user_ids.iter().copied());
112 result.hosted_worktrees.insert(worktree_id, worktree);
113 } else {
114 if let Some(worktree) = self.worktrees.get(&worktree_id) {
115 result
116 .guest_worktree_ids
117 .insert(worktree_id, worktree.connection_ids());
118 result
119 .collaborator_ids
120 .extend(worktree.collaborator_user_ids.iter().copied());
121 }
122 }
123 }
124
125 Ok(result)
126 }
127
128 #[cfg(test)]
129 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130 self.channels.get(&id)
131 }
132
133 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134 if let Some(connection) = self.connections.get_mut(&connection_id) {
135 connection.channels.insert(channel_id);
136 self.channels
137 .entry(channel_id)
138 .or_default()
139 .connection_ids
140 .insert(connection_id);
141 }
142 }
143
144 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145 if let Some(connection) = self.connections.get_mut(&connection_id) {
146 connection.channels.remove(&channel_id);
147 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148 entry.get_mut().connection_ids.remove(&connection_id);
149 if entry.get_mut().connection_ids.is_empty() {
150 entry.remove();
151 }
152 }
153 }
154 }
155
156 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157 Ok(self
158 .connections
159 .get(&connection_id)
160 .ok_or_else(|| anyhow!("unknown connection"))?
161 .user_id)
162 }
163
164 pub fn connection_ids_for_user<'a>(
165 &'a self,
166 user_id: UserId,
167 ) -> impl 'a + Iterator<Item = ConnectionId> {
168 self.connections_by_user_id
169 .get(&user_id)
170 .into_iter()
171 .flatten()
172 .copied()
173 }
174
175 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
176 let mut collaborators = HashMap::new();
177 for worktree_id in self
178 .visible_worktrees_by_user_id
179 .get(&user_id)
180 .unwrap_or(&HashSet::new())
181 {
182 let worktree = &self.worktrees[worktree_id];
183
184 let mut guests = HashSet::new();
185 if let Ok(share) = worktree.share() {
186 for guest_connection_id in share.guest_connection_ids.keys() {
187 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188 guests.insert(user_id.to_proto());
189 }
190 }
191 }
192
193 if let Ok(host_user_id) = self
194 .user_id_for_connection(worktree.host_connection_id)
195 .context("stale worktree host connection")
196 {
197 let host =
198 collaborators
199 .entry(host_user_id)
200 .or_insert_with(|| proto::Collaborator {
201 user_id: host_user_id.to_proto(),
202 worktrees: Vec::new(),
203 });
204 host.worktrees.push(proto::WorktreeMetadata {
205 id: *worktree_id,
206 root_name: worktree.root_name.clone(),
207 is_shared: worktree.share().is_ok(),
208 guests: guests.into_iter().collect(),
209 });
210 }
211 }
212
213 collaborators.into_values().collect()
214 }
215
216 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
217 let worktree_id = self.next_worktree_id;
218 for collaborator_user_id in &worktree.collaborator_user_ids {
219 self.visible_worktrees_by_user_id
220 .entry(*collaborator_user_id)
221 .or_default()
222 .insert(worktree_id);
223 }
224 self.next_worktree_id += 1;
225 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
226 connection.worktrees.insert(worktree_id);
227 }
228 self.worktrees.insert(worktree_id, worktree);
229
230 #[cfg(test)]
231 self.check_invariants();
232
233 worktree_id
234 }
235
236 pub fn remove_worktree(
237 &mut self,
238 worktree_id: u64,
239 acting_connection_id: ConnectionId,
240 ) -> tide::Result<Worktree> {
241 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
242 if e.get().host_connection_id != acting_connection_id {
243 Err(anyhow!("not your worktree"))?;
244 }
245 e.remove()
246 } else {
247 return Err(anyhow!("no such worktree"))?;
248 };
249
250 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
251 connection.worktrees.remove(&worktree_id);
252 }
253
254 if let Some(share) = &worktree.share {
255 for connection_id in share.guest_connection_ids.keys() {
256 if let Some(connection) = self.connections.get_mut(connection_id) {
257 connection.worktrees.remove(&worktree_id);
258 }
259 }
260 }
261
262 for collaborator_user_id in &worktree.collaborator_user_ids {
263 if let Some(visible_worktrees) = self
264 .visible_worktrees_by_user_id
265 .get_mut(&collaborator_user_id)
266 {
267 visible_worktrees.remove(&worktree_id);
268 }
269 }
270
271 #[cfg(test)]
272 self.check_invariants();
273
274 Ok(worktree)
275 }
276
277 pub fn share_worktree(
278 &mut self,
279 worktree_id: u64,
280 connection_id: ConnectionId,
281 entries: HashMap<u64, proto::Entry>,
282 ) -> Option<Vec<UserId>> {
283 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
284 if worktree.host_connection_id == connection_id {
285 worktree.share = Some(WorktreeShare {
286 guest_connection_ids: Default::default(),
287 active_replica_ids: Default::default(),
288 entries,
289 });
290 return Some(worktree.collaborator_user_ids.clone());
291 }
292 }
293 None
294 }
295
296 pub fn unshare_worktree(
297 &mut self,
298 worktree_id: u64,
299 acting_connection_id: ConnectionId,
300 ) -> tide::Result<UnsharedWorktree> {
301 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
302 worktree
303 } else {
304 return Err(anyhow!("no such worktree"))?;
305 };
306
307 if worktree.host_connection_id != acting_connection_id {
308 return Err(anyhow!("not your worktree"))?;
309 }
310
311 let connection_ids = worktree.connection_ids();
312
313 if let Some(_) = worktree.share.take() {
314 for connection_id in &connection_ids {
315 if let Some(connection) = self.connections.get_mut(connection_id) {
316 connection.worktrees.remove(&worktree_id);
317 }
318 }
319 Ok(UnsharedWorktree {
320 connection_ids,
321 collaborator_ids: worktree.collaborator_user_ids.clone(),
322 })
323 } else {
324 Err(anyhow!("worktree is not shared"))?
325 }
326 }
327
328 pub fn join_worktree(
329 &mut self,
330 connection_id: ConnectionId,
331 user_id: UserId,
332 worktree_id: u64,
333 ) -> tide::Result<JoinedWorktree> {
334 let connection = self
335 .connections
336 .get_mut(&connection_id)
337 .ok_or_else(|| anyhow!("no such connection"))?;
338 let worktree = self
339 .worktrees
340 .get_mut(&worktree_id)
341 .and_then(|worktree| {
342 if worktree.collaborator_user_ids.contains(&user_id) {
343 Some(worktree)
344 } else {
345 None
346 }
347 })
348 .ok_or_else(|| anyhow!("no such worktree"))?;
349
350 let share = worktree.share_mut()?;
351 connection.worktrees.insert(worktree_id);
352
353 let mut replica_id = 1;
354 while share.active_replica_ids.contains(&replica_id) {
355 replica_id += 1;
356 }
357 share.active_replica_ids.insert(replica_id);
358 share.guest_connection_ids.insert(connection_id, replica_id);
359 Ok(JoinedWorktree {
360 replica_id,
361 worktree,
362 })
363 }
364
365 pub fn leave_worktree(
366 &mut self,
367 connection_id: ConnectionId,
368 worktree_id: u64,
369 ) -> Option<LeftWorktree> {
370 let worktree = self.worktrees.get_mut(&worktree_id)?;
371 let share = worktree.share.as_mut()?;
372 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
373 share.active_replica_ids.remove(&replica_id);
374 Some(LeftWorktree {
375 connection_ids: worktree.connection_ids(),
376 collaborator_ids: worktree.collaborator_user_ids.clone(),
377 })
378 }
379
380 pub fn update_worktree(
381 &mut self,
382 connection_id: ConnectionId,
383 worktree_id: u64,
384 removed_entries: &[u64],
385 updated_entries: &[proto::Entry],
386 ) -> tide::Result<Vec<ConnectionId>> {
387 let worktree = self.write_worktree(worktree_id, connection_id)?;
388 let share = worktree.share_mut()?;
389 for entry_id in removed_entries {
390 share.entries.remove(&entry_id);
391 }
392 for entry in updated_entries {
393 share.entries.insert(entry.id, entry.clone());
394 }
395 Ok(worktree.connection_ids())
396 }
397
398 pub fn worktree_host_connection_id(
399 &self,
400 connection_id: ConnectionId,
401 worktree_id: u64,
402 ) -> tide::Result<ConnectionId> {
403 Ok(self
404 .read_worktree(worktree_id, connection_id)?
405 .host_connection_id)
406 }
407
408 pub fn worktree_guest_connection_ids(
409 &self,
410 connection_id: ConnectionId,
411 worktree_id: u64,
412 ) -> tide::Result<Vec<ConnectionId>> {
413 Ok(self
414 .read_worktree(worktree_id, connection_id)?
415 .share()?
416 .guest_connection_ids
417 .keys()
418 .copied()
419 .collect())
420 }
421
422 pub fn worktree_connection_ids(
423 &self,
424 connection_id: ConnectionId,
425 worktree_id: u64,
426 ) -> tide::Result<Vec<ConnectionId>> {
427 Ok(self
428 .read_worktree(worktree_id, connection_id)?
429 .connection_ids())
430 }
431
432 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
433 Some(self.channels.get(&channel_id)?.connection_ids())
434 }
435
436 fn read_worktree(
437 &self,
438 worktree_id: u64,
439 connection_id: ConnectionId,
440 ) -> tide::Result<&Worktree> {
441 let worktree = self
442 .worktrees
443 .get(&worktree_id)
444 .ok_or_else(|| anyhow!("worktree not found"))?;
445
446 if worktree.host_connection_id == connection_id
447 || worktree
448 .share()?
449 .guest_connection_ids
450 .contains_key(&connection_id)
451 {
452 Ok(worktree)
453 } else {
454 Err(anyhow!(
455 "{} is not a member of worktree {}",
456 connection_id,
457 worktree_id
458 ))?
459 }
460 }
461
462 fn write_worktree(
463 &mut self,
464 worktree_id: u64,
465 connection_id: ConnectionId,
466 ) -> tide::Result<&mut Worktree> {
467 let worktree = self
468 .worktrees
469 .get_mut(&worktree_id)
470 .ok_or_else(|| anyhow!("worktree not found"))?;
471
472 if worktree.host_connection_id == connection_id
473 || worktree.share.as_ref().map_or(false, |share| {
474 share.guest_connection_ids.contains_key(&connection_id)
475 })
476 {
477 Ok(worktree)
478 } else {
479 Err(anyhow!(
480 "{} is not a member of worktree {}",
481 connection_id,
482 worktree_id
483 ))?
484 }
485 }
486
487 #[cfg(test)]
488 fn check_invariants(&self) {
489 for (connection_id, connection) in &self.connections {
490 for worktree_id in &connection.worktrees {
491 let worktree = &self.worktrees.get(&worktree_id).unwrap();
492 if worktree.host_connection_id != *connection_id {
493 assert!(worktree
494 .share()
495 .unwrap()
496 .guest_connection_ids
497 .contains_key(connection_id));
498 }
499 }
500 for channel_id in &connection.channels {
501 let channel = self.channels.get(channel_id).unwrap();
502 assert!(channel.connection_ids.contains(connection_id));
503 }
504 assert!(self
505 .connections_by_user_id
506 .get(&connection.user_id)
507 .unwrap()
508 .contains(connection_id));
509 }
510
511 for (user_id, connection_ids) in &self.connections_by_user_id {
512 for connection_id in connection_ids {
513 assert_eq!(
514 self.connections.get(connection_id).unwrap().user_id,
515 *user_id
516 );
517 }
518 }
519
520 for (worktree_id, worktree) in &self.worktrees {
521 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
522 assert!(host_connection.worktrees.contains(worktree_id));
523
524 for collaborator_id in &worktree.collaborator_user_ids {
525 let visible_worktree_ids = self
526 .visible_worktrees_by_user_id
527 .get(collaborator_id)
528 .unwrap();
529 assert!(visible_worktree_ids.contains(worktree_id));
530 }
531
532 if let Some(share) = &worktree.share {
533 for guest_connection_id in share.guest_connection_ids.keys() {
534 let guest_connection = self.connections.get(guest_connection_id).unwrap();
535 assert!(guest_connection.worktrees.contains(worktree_id));
536 }
537 assert_eq!(
538 share.active_replica_ids.len(),
539 share.guest_connection_ids.len(),
540 );
541 assert_eq!(
542 share.active_replica_ids,
543 share
544 .guest_connection_ids
545 .values()
546 .copied()
547 .collect::<HashSet<_>>(),
548 );
549 }
550 }
551
552 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
553 for worktree_id in visible_worktree_ids {
554 let worktree = self.worktrees.get(worktree_id).unwrap();
555 assert!(worktree.collaborator_user_ids.contains(user_id));
556 }
557 }
558
559 for (channel_id, channel) in &self.channels {
560 for connection_id in &channel.connection_ids {
561 let connection = self.connections.get(connection_id).unwrap();
562 assert!(connection.channels.contains(channel_id));
563 }
564 }
565 }
566}
567
568impl Worktree {
569 pub fn connection_ids(&self) -> Vec<ConnectionId> {
570 if let Some(share) = &self.share {
571 share
572 .guest_connection_ids
573 .keys()
574 .copied()
575 .chain(Some(self.host_connection_id))
576 .collect()
577 } else {
578 vec![self.host_connection_id]
579 }
580 }
581
582 pub fn share(&self) -> tide::Result<&WorktreeShare> {
583 Ok(self
584 .share
585 .as_ref()
586 .ok_or_else(|| anyhow!("worktree is not shared"))?)
587 }
588
589 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
590 Ok(self
591 .share
592 .as_mut()
593 .ok_or_else(|| anyhow!("worktree is not shared"))?)
594 }
595}
596
597impl Channel {
598 fn connection_ids(&self) -> Vec<ConnectionId> {
599 self.connection_ids.iter().copied().collect()
600 }
601}