1use crate::db::{ChannelId, UserId};
2use crate::errors::TideResultExt;
3use anyhow::anyhow;
4use std::collections::{hash_map, HashMap, HashSet};
5use zrpc::{proto, ConnectionId};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 worktrees: HashMap<u64, Worktree>,
12 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_worktree_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 worktrees: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Worktree {
24 pub host_connection_id: ConnectionId,
25 pub collaborator_user_ids: Vec<UserId>,
26 pub root_name: String,
27 pub share: Option<WorktreeShare>,
28}
29
30pub struct WorktreeShare {
31 pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
32 pub active_replica_ids: HashSet<ReplicaId>,
33 pub entries: HashMap<u64, proto::Entry>,
34}
35
36#[derive(Default)]
37pub struct Channel {
38 pub connection_ids: HashSet<ConnectionId>,
39}
40
/// Identifies a participant's replica within a shared worktree.
/// `Store::join_worktree` hands guests ids starting at 1.
pub type ReplicaId = u16;
42
43#[derive(Default)]
44pub struct RemovedConnectionState {
45 pub hosted_worktrees: HashMap<u64, Worktree>,
46 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
47 pub collaborator_ids: HashSet<UserId>,
48}
49
50pub struct JoinedWorktree<'a> {
51 pub replica_id: ReplicaId,
52 pub worktree: &'a Worktree,
53}
54
55pub struct UnsharedWorktree {
56 pub connection_ids: Vec<ConnectionId>,
57 pub collaborator_ids: Vec<UserId>,
58}
59
60pub struct LeftWorktree {
61 pub connection_ids: Vec<ConnectionId>,
62 pub collaborator_ids: Vec<UserId>,
63}
64
65impl Store {
66 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
67 self.connections.insert(
68 connection_id,
69 ConnectionState {
70 user_id,
71 worktrees: Default::default(),
72 channels: Default::default(),
73 },
74 );
75 self.connections_by_user_id
76 .entry(user_id)
77 .or_default()
78 .insert(connection_id);
79 }
80
81 pub fn remove_connection(
82 &mut self,
83 connection_id: ConnectionId,
84 ) -> tide::Result<RemovedConnectionState> {
85 let connection = if let Some(connection) = self.connections.get(&connection_id) {
86 connection
87 } else {
88 return Err(anyhow!("no such connection"))?;
89 };
90
91 for channel_id in &connection.channels {
92 if let Some(channel) = self.channels.get_mut(&channel_id) {
93 channel.connection_ids.remove(&connection_id);
94 }
95 }
96
97 let user_connections = self
98 .connections_by_user_id
99 .get_mut(&connection.user_id)
100 .unwrap();
101 user_connections.remove(&connection_id);
102 if user_connections.is_empty() {
103 self.connections_by_user_id.remove(&connection.user_id);
104 }
105
106 let mut result = RemovedConnectionState::default();
107 for worktree_id in connection.worktrees.clone() {
108 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109 result
110 .collaborator_ids
111 .extend(worktree.collaborator_user_ids.iter().copied());
112 result.hosted_worktrees.insert(worktree_id, worktree);
113 } else {
114 if let Some(worktree) = self.worktrees.get(&worktree_id) {
115 result
116 .guest_worktree_ids
117 .insert(worktree_id, worktree.connection_ids());
118 result
119 .collaborator_ids
120 .extend(worktree.collaborator_user_ids.iter().copied());
121 }
122 }
123 }
124
125 Ok(result)
126 }
127
128 #[cfg(test)]
129 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130 self.channels.get(&id)
131 }
132
133 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134 if let Some(connection) = self.connections.get_mut(&connection_id) {
135 connection.channels.insert(channel_id);
136 self.channels
137 .entry(channel_id)
138 .or_default()
139 .connection_ids
140 .insert(connection_id);
141 }
142 }
143
144 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145 if let Some(connection) = self.connections.get_mut(&connection_id) {
146 connection.channels.remove(&channel_id);
147 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148 entry.get_mut().connection_ids.remove(&connection_id);
149 if entry.get_mut().connection_ids.is_empty() {
150 entry.remove();
151 }
152 }
153 }
154 }
155
156 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157 Ok(self
158 .connections
159 .get(&connection_id)
160 .ok_or_else(|| anyhow!("unknown connection"))?
161 .user_id)
162 }
163
164 pub fn connection_ids_for_user<'a>(
165 &'a self,
166 user_id: UserId,
167 ) -> impl 'a + Iterator<Item = ConnectionId> {
168 self.connections_by_user_id
169 .get(&user_id)
170 .into_iter()
171 .flatten()
172 .copied()
173 }
174
175 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
176 let mut collaborators = HashMap::new();
177 for worktree_id in self
178 .visible_worktrees_by_user_id
179 .get(&user_id)
180 .unwrap_or(&HashSet::new())
181 {
182 let worktree = &self.worktrees[worktree_id];
183
184 let mut guests = HashSet::new();
185 if let Ok(share) = worktree.share() {
186 for guest_connection_id in share.guest_connection_ids.keys() {
187 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188 guests.insert(user_id.to_proto());
189 }
190 }
191 }
192
193 if let Ok(host_user_id) = self
194 .user_id_for_connection(worktree.host_connection_id)
195 .context("stale worktree host connection")
196 {
197 let host =
198 collaborators
199 .entry(host_user_id)
200 .or_insert_with(|| proto::Collaborator {
201 user_id: host_user_id.to_proto(),
202 worktrees: Vec::new(),
203 });
204 host.worktrees.push(proto::WorktreeMetadata {
205 root_name: worktree.root_name.clone(),
206 is_shared: worktree.share().is_ok(),
207 participants: guests.into_iter().collect(),
208 });
209 }
210 }
211
212 collaborators.into_values().collect()
213 }
214
215 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
216 let worktree_id = self.next_worktree_id;
217 for collaborator_user_id in &worktree.collaborator_user_ids {
218 self.visible_worktrees_by_user_id
219 .entry(*collaborator_user_id)
220 .or_default()
221 .insert(worktree_id);
222 }
223 self.next_worktree_id += 1;
224 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
225 connection.worktrees.insert(worktree_id);
226 }
227 self.worktrees.insert(worktree_id, worktree);
228
229 #[cfg(test)]
230 self.check_invariants();
231
232 worktree_id
233 }
234
235 pub fn remove_worktree(
236 &mut self,
237 worktree_id: u64,
238 acting_connection_id: ConnectionId,
239 ) -> tide::Result<Worktree> {
240 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
241 if e.get().host_connection_id != acting_connection_id {
242 Err(anyhow!("not your worktree"))?;
243 }
244 e.remove()
245 } else {
246 return Err(anyhow!("no such worktree"))?;
247 };
248
249 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
250 connection.worktrees.remove(&worktree_id);
251 }
252
253 if let Some(share) = &worktree.share {
254 for connection_id in share.guest_connection_ids.keys() {
255 if let Some(connection) = self.connections.get_mut(connection_id) {
256 connection.worktrees.remove(&worktree_id);
257 }
258 }
259 }
260
261 for collaborator_user_id in &worktree.collaborator_user_ids {
262 if let Some(visible_worktrees) = self
263 .visible_worktrees_by_user_id
264 .get_mut(&collaborator_user_id)
265 {
266 visible_worktrees.remove(&worktree_id);
267 }
268 }
269
270 #[cfg(test)]
271 self.check_invariants();
272
273 Ok(worktree)
274 }
275
276 pub fn share_worktree(
277 &mut self,
278 worktree_id: u64,
279 connection_id: ConnectionId,
280 entries: HashMap<u64, proto::Entry>,
281 ) -> Option<Vec<UserId>> {
282 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
283 if worktree.host_connection_id == connection_id {
284 worktree.share = Some(WorktreeShare {
285 guest_connection_ids: Default::default(),
286 active_replica_ids: Default::default(),
287 entries,
288 });
289 return Some(worktree.collaborator_user_ids.clone());
290 }
291 }
292 None
293 }
294
295 pub fn unshare_worktree(
296 &mut self,
297 worktree_id: u64,
298 acting_connection_id: ConnectionId,
299 ) -> tide::Result<UnsharedWorktree> {
300 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
301 worktree
302 } else {
303 return Err(anyhow!("no such worktree"))?;
304 };
305
306 if worktree.host_connection_id != acting_connection_id {
307 return Err(anyhow!("not your worktree"))?;
308 }
309
310 let connection_ids = worktree.connection_ids();
311
312 if let Some(_) = worktree.share.take() {
313 for connection_id in &connection_ids {
314 if let Some(connection) = self.connections.get_mut(connection_id) {
315 connection.worktrees.remove(&worktree_id);
316 }
317 }
318 Ok(UnsharedWorktree {
319 connection_ids,
320 collaborator_ids: worktree.collaborator_user_ids.clone(),
321 })
322 } else {
323 Err(anyhow!("worktree is not shared"))?
324 }
325 }
326
327 pub fn join_worktree(
328 &mut self,
329 connection_id: ConnectionId,
330 user_id: UserId,
331 worktree_id: u64,
332 ) -> tide::Result<JoinedWorktree> {
333 let connection = self
334 .connections
335 .get_mut(&connection_id)
336 .ok_or_else(|| anyhow!("no such connection"))?;
337 let worktree = self
338 .worktrees
339 .get_mut(&worktree_id)
340 .and_then(|worktree| {
341 if worktree.collaborator_user_ids.contains(&user_id) {
342 Some(worktree)
343 } else {
344 None
345 }
346 })
347 .ok_or_else(|| anyhow!("no such worktree"))?;
348
349 let share = worktree.share_mut()?;
350 connection.worktrees.insert(worktree_id);
351
352 let mut replica_id = 1;
353 while share.active_replica_ids.contains(&replica_id) {
354 replica_id += 1;
355 }
356 share.active_replica_ids.insert(replica_id);
357 share.guest_connection_ids.insert(connection_id, replica_id);
358 Ok(JoinedWorktree {
359 replica_id,
360 worktree,
361 })
362 }
363
364 pub fn leave_worktree(
365 &mut self,
366 connection_id: ConnectionId,
367 worktree_id: u64,
368 ) -> Option<LeftWorktree> {
369 let worktree = self.worktrees.get_mut(&worktree_id)?;
370 let share = worktree.share.as_mut()?;
371 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
372 share.active_replica_ids.remove(&replica_id);
373 Some(LeftWorktree {
374 connection_ids: worktree.connection_ids(),
375 collaborator_ids: worktree.collaborator_user_ids.clone(),
376 })
377 }
378
379 pub fn update_worktree(
380 &mut self,
381 connection_id: ConnectionId,
382 worktree_id: u64,
383 removed_entries: &[u64],
384 updated_entries: &[proto::Entry],
385 ) -> tide::Result<Vec<ConnectionId>> {
386 let worktree = self.write_worktree(worktree_id, connection_id)?;
387 let share = worktree.share_mut()?;
388 for entry_id in removed_entries {
389 share.entries.remove(&entry_id);
390 }
391 for entry in updated_entries {
392 share.entries.insert(entry.id, entry.clone());
393 }
394 Ok(worktree.connection_ids())
395 }
396
397 pub fn worktree_host_connection_id(
398 &self,
399 connection_id: ConnectionId,
400 worktree_id: u64,
401 ) -> tide::Result<ConnectionId> {
402 Ok(self
403 .read_worktree(worktree_id, connection_id)?
404 .host_connection_id)
405 }
406
407 pub fn worktree_guest_connection_ids(
408 &self,
409 connection_id: ConnectionId,
410 worktree_id: u64,
411 ) -> tide::Result<Vec<ConnectionId>> {
412 Ok(self
413 .read_worktree(worktree_id, connection_id)?
414 .share()?
415 .guest_connection_ids
416 .keys()
417 .copied()
418 .collect())
419 }
420
421 pub fn worktree_connection_ids(
422 &self,
423 connection_id: ConnectionId,
424 worktree_id: u64,
425 ) -> tide::Result<Vec<ConnectionId>> {
426 Ok(self
427 .read_worktree(worktree_id, connection_id)?
428 .connection_ids())
429 }
430
431 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
432 Some(self.channels.get(&channel_id)?.connection_ids())
433 }
434
435 fn read_worktree(
436 &self,
437 worktree_id: u64,
438 connection_id: ConnectionId,
439 ) -> tide::Result<&Worktree> {
440 let worktree = self
441 .worktrees
442 .get(&worktree_id)
443 .ok_or_else(|| anyhow!("worktree not found"))?;
444
445 if worktree.host_connection_id == connection_id
446 || worktree
447 .share()?
448 .guest_connection_ids
449 .contains_key(&connection_id)
450 {
451 Ok(worktree)
452 } else {
453 Err(anyhow!(
454 "{} is not a member of worktree {}",
455 connection_id,
456 worktree_id
457 ))?
458 }
459 }
460
461 fn write_worktree(
462 &mut self,
463 worktree_id: u64,
464 connection_id: ConnectionId,
465 ) -> tide::Result<&mut Worktree> {
466 let worktree = self
467 .worktrees
468 .get_mut(&worktree_id)
469 .ok_or_else(|| anyhow!("worktree not found"))?;
470
471 if worktree.host_connection_id == connection_id
472 || worktree.share.as_ref().map_or(false, |share| {
473 share.guest_connection_ids.contains_key(&connection_id)
474 })
475 {
476 Ok(worktree)
477 } else {
478 Err(anyhow!(
479 "{} is not a member of worktree {}",
480 connection_id,
481 worktree_id
482 ))?
483 }
484 }
485
486 #[cfg(test)]
487 fn check_invariants(&self) {
488 for (connection_id, connection) in &self.connections {
489 for worktree_id in &connection.worktrees {
490 let worktree = &self.worktrees.get(&worktree_id).unwrap();
491 if worktree.host_connection_id != *connection_id {
492 assert!(worktree
493 .share()
494 .unwrap()
495 .guest_connection_ids
496 .contains_key(connection_id));
497 }
498 }
499 for channel_id in &connection.channels {
500 let channel = self.channels.get(channel_id).unwrap();
501 assert!(channel.connection_ids.contains(connection_id));
502 }
503 assert!(self
504 .connections_by_user_id
505 .get(&connection.user_id)
506 .unwrap()
507 .contains(connection_id));
508 }
509
510 for (user_id, connection_ids) in &self.connections_by_user_id {
511 for connection_id in connection_ids {
512 assert_eq!(
513 self.connections.get(connection_id).unwrap().user_id,
514 *user_id
515 );
516 }
517 }
518
519 for (worktree_id, worktree) in &self.worktrees {
520 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
521 assert!(host_connection.worktrees.contains(worktree_id));
522
523 for collaborator_id in &worktree.collaborator_user_ids {
524 let visible_worktree_ids = self
525 .visible_worktrees_by_user_id
526 .get(collaborator_id)
527 .unwrap();
528 assert!(visible_worktree_ids.contains(worktree_id));
529 }
530
531 if let Some(share) = &worktree.share {
532 for guest_connection_id in share.guest_connection_ids.keys() {
533 let guest_connection = self.connections.get(guest_connection_id).unwrap();
534 assert!(guest_connection.worktrees.contains(worktree_id));
535 }
536 assert_eq!(
537 share.active_replica_ids.len(),
538 share.guest_connection_ids.len(),
539 );
540 assert_eq!(
541 share.active_replica_ids,
542 share
543 .guest_connection_ids
544 .values()
545 .copied()
546 .collect::<HashSet<_>>(),
547 );
548 }
549 }
550
551 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
552 for worktree_id in visible_worktree_ids {
553 let worktree = self.worktrees.get(worktree_id).unwrap();
554 assert!(worktree.collaborator_user_ids.contains(user_id));
555 }
556 }
557
558 for (channel_id, channel) in &self.channels {
559 for connection_id in &channel.connection_ids {
560 let connection = self.connections.get(connection_id).unwrap();
561 assert!(connection.channels.contains(channel_id));
562 }
563 }
564 }
565}
566
567impl Worktree {
568 pub fn connection_ids(&self) -> Vec<ConnectionId> {
569 if let Some(share) = &self.share {
570 share
571 .guest_connection_ids
572 .keys()
573 .copied()
574 .chain(Some(self.host_connection_id))
575 .collect()
576 } else {
577 vec![self.host_connection_id]
578 }
579 }
580
581 pub fn share(&self) -> tide::Result<&WorktreeShare> {
582 Ok(self
583 .share
584 .as_ref()
585 .ok_or_else(|| anyhow!("worktree is not shared"))?)
586 }
587
588 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
589 Ok(self
590 .share
591 .as_mut()
592 .ok_or_else(|| anyhow!("worktree is not shared"))?)
593 }
594}
595
596impl Channel {
597 fn connection_ids(&self) -> Vec<ConnectionId> {
598 self.connection_ids.iter().copied().collect()
599 }
600}