1use crate::db::{ChannelId, UserId};
2use crate::errors::TideResultExt;
3use anyhow::anyhow;
4use std::collections::{hash_map, HashMap, HashSet};
5use zrpc::{proto, ConnectionId};
6
/// In-memory registry of connections, worktrees, and channels.
///
/// Several maps here are reverse indices of one another; in test builds
/// `check_invariants` asserts that they agree.
#[derive(Default)]
pub struct Store {
    /// Per-connection state, keyed by connection id.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// Reverse index from a user to the ids of all of that user's connections.
    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
    /// All registered worktrees, keyed by the id assigned in `add_worktree`.
    worktrees: HashMap<u64, Worktree>,
    /// For each user, the ids of worktrees that list the user in `collaborator_user_ids`.
    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
    /// Channel membership, keyed by channel id.
    channels: HashMap<ChannelId, Channel>,
    /// Id that the next call to `add_worktree` will hand out.
    next_worktree_id: u64,
}
16
/// Bookkeeping for a single connection.
struct ConnectionState {
    /// The user that owns this connection.
    user_id: UserId,
    /// Ids of worktrees this connection hosts or has joined as a guest.
    worktrees: HashSet<u64>,
    /// Ids of channels this connection has joined.
    channels: HashSet<ChannelId>,
}
22
/// A worktree registered by a host connection.
pub struct Worktree {
    /// The connection that registered (hosts) this worktree.
    pub host_connection_id: ConnectionId,
    /// Users who can see this worktree and are allowed to join it.
    pub collaborator_user_ids: Vec<UserId>,
    /// Name of the worktree's root, reported to collaborators.
    pub root_name: String,
    /// `Some` while the worktree is shared with guests, `None` otherwise.
    pub share: Option<WorktreeShare>,
}
29
/// State that exists only while a worktree is shared.
pub struct WorktreeShare {
    /// Replica id assigned to each guest connection.
    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
    /// Replica ids currently held by guests; expected to equal the set of values
    /// in `guest_connection_ids`.
    pub active_replica_ids: HashSet<ReplicaId>,
    /// The worktree's entries, keyed by entry id.
    pub entries: HashMap<u64, proto::Entry>,
}
35
/// Membership of a single channel.
#[derive(Default)]
pub struct Channel {
    /// Connections currently in the channel.
    pub connection_ids: HashSet<ConnectionId>,
}
40
/// Identifies a participant's replica of a shared worktree. Guests are assigned
/// ids starting at 1 by `Store::join_worktree`; presumably 0 is reserved for the
/// host — TODO confirm.
pub type ReplicaId = u16;
42
/// Summary of what `Store::remove_connection` tore down for a connection.
#[derive(Default)]
pub struct RemovedConnectionState {
    /// Worktrees the connection hosted; they have been removed from the store.
    pub hosted_worktrees: HashMap<u64, Worktree>,
    /// For each worktree the connection had joined as a guest, connection ids
    /// associated with that worktree.
    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
    /// Collaborator user ids of every affected worktree.
    pub collaborator_ids: HashSet<UserId>,
}
49
/// Result of `Store::join_worktree`.
pub struct JoinedWorktree<'a> {
    /// Replica id assigned to the joining guest.
    pub replica_id: ReplicaId,
    /// The worktree that was joined.
    pub worktree: &'a Worktree,
}
54
/// Connections and collaborators of a worktree, returned by operations that
/// change its membership (`unshare_worktree`, `leave_worktree`).
pub struct WorktreeMetadata {
    /// Connection ids associated with the worktree.
    pub connection_ids: Vec<ConnectionId>,
    /// The worktree's collaborator user ids.
    pub collaborator_ids: Vec<UserId>,
}
59
60impl Store {
61 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
62 self.connections.insert(
63 connection_id,
64 ConnectionState {
65 user_id,
66 worktrees: Default::default(),
67 channels: Default::default(),
68 },
69 );
70 self.connections_by_user_id
71 .entry(user_id)
72 .or_default()
73 .insert(connection_id);
74 }
75
76 pub fn remove_connection(
77 &mut self,
78 connection_id: ConnectionId,
79 ) -> tide::Result<RemovedConnectionState> {
80 let connection = if let Some(connection) = self.connections.get(&connection_id) {
81 connection
82 } else {
83 return Err(anyhow!("no such connection"))?;
84 };
85
86 for channel_id in &connection.channels {
87 if let Some(channel) = self.channels.get_mut(&channel_id) {
88 channel.connection_ids.remove(&connection_id);
89 }
90 }
91
92 let user_connections = self
93 .connections_by_user_id
94 .get_mut(&connection.user_id)
95 .unwrap();
96 user_connections.remove(&connection_id);
97 if user_connections.is_empty() {
98 self.connections_by_user_id.remove(&connection.user_id);
99 }
100
101 let mut result = RemovedConnectionState::default();
102 for worktree_id in connection.worktrees.clone() {
103 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
104 result
105 .collaborator_ids
106 .extend(worktree.collaborator_user_ids.iter().copied());
107 result.hosted_worktrees.insert(worktree_id, worktree);
108 } else {
109 if let Some(worktree) = self.worktrees.get(&worktree_id) {
110 result
111 .guest_worktree_ids
112 .insert(worktree_id, worktree.connection_ids());
113 result
114 .collaborator_ids
115 .extend(worktree.collaborator_user_ids.iter().copied());
116 }
117 }
118 }
119
120 Ok(result)
121 }
122
123 #[cfg(test)]
124 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
125 self.channels.get(&id)
126 }
127
128 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
129 if let Some(connection) = self.connections.get_mut(&connection_id) {
130 connection.channels.insert(channel_id);
131 self.channels
132 .entry(channel_id)
133 .or_default()
134 .connection_ids
135 .insert(connection_id);
136 }
137 }
138
139 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
140 if let Some(connection) = self.connections.get_mut(&connection_id) {
141 connection.channels.remove(&channel_id);
142 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
143 entry.get_mut().connection_ids.remove(&connection_id);
144 if entry.get_mut().connection_ids.is_empty() {
145 entry.remove();
146 }
147 }
148 }
149 }
150
151 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
152 Ok(self
153 .connections
154 .get(&connection_id)
155 .ok_or_else(|| anyhow!("unknown connection"))?
156 .user_id)
157 }
158
159 pub fn connection_ids_for_user<'a>(
160 &'a self,
161 user_id: UserId,
162 ) -> impl 'a + Iterator<Item = ConnectionId> {
163 self.connections_by_user_id
164 .get(&user_id)
165 .into_iter()
166 .flatten()
167 .copied()
168 }
169
170 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
171 let mut collaborators = HashMap::new();
172 for worktree_id in self
173 .visible_worktrees_by_user_id
174 .get(&user_id)
175 .unwrap_or(&HashSet::new())
176 {
177 let worktree = &self.worktrees[worktree_id];
178
179 let mut guests = HashSet::new();
180 if let Ok(share) = worktree.share() {
181 for guest_connection_id in share.guest_connection_ids.keys() {
182 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
183 guests.insert(user_id.to_proto());
184 }
185 }
186 }
187
188 if let Ok(host_user_id) = self
189 .user_id_for_connection(worktree.host_connection_id)
190 .context("stale worktree host connection")
191 {
192 let host =
193 collaborators
194 .entry(host_user_id)
195 .or_insert_with(|| proto::Collaborator {
196 user_id: host_user_id.to_proto(),
197 worktrees: Vec::new(),
198 });
199 host.worktrees.push(proto::WorktreeMetadata {
200 root_name: worktree.root_name.clone(),
201 is_shared: worktree.share().is_ok(),
202 participants: guests.into_iter().collect(),
203 });
204 }
205 }
206
207 collaborators.into_values().collect()
208 }
209
210 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
211 let worktree_id = self.next_worktree_id;
212 for collaborator_user_id in &worktree.collaborator_user_ids {
213 self.visible_worktrees_by_user_id
214 .entry(*collaborator_user_id)
215 .or_default()
216 .insert(worktree_id);
217 }
218 self.next_worktree_id += 1;
219 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
220 connection.worktrees.insert(worktree_id);
221 }
222 self.worktrees.insert(worktree_id, worktree);
223
224 #[cfg(test)]
225 self.check_invariants();
226
227 worktree_id
228 }
229
230 pub fn remove_worktree(
231 &mut self,
232 worktree_id: u64,
233 acting_connection_id: ConnectionId,
234 ) -> tide::Result<Worktree> {
235 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
236 if e.get().host_connection_id != acting_connection_id {
237 Err(anyhow!("not your worktree"))?;
238 }
239 e.remove()
240 } else {
241 return Err(anyhow!("no such worktree"))?;
242 };
243
244 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
245 connection.worktrees.remove(&worktree_id);
246 }
247
248 if let Some(share) = &worktree.share {
249 for connection_id in share.guest_connection_ids.keys() {
250 if let Some(connection) = self.connections.get_mut(connection_id) {
251 connection.worktrees.remove(&worktree_id);
252 }
253 }
254 }
255
256 for collaborator_user_id in &worktree.collaborator_user_ids {
257 if let Some(visible_worktrees) = self
258 .visible_worktrees_by_user_id
259 .get_mut(&collaborator_user_id)
260 {
261 visible_worktrees.remove(&worktree_id);
262 }
263 }
264
265 #[cfg(test)]
266 self.check_invariants();
267
268 Ok(worktree)
269 }
270
271 pub fn share_worktree(
272 &mut self,
273 worktree_id: u64,
274 connection_id: ConnectionId,
275 entries: HashMap<u64, proto::Entry>,
276 ) -> Option<Vec<UserId>> {
277 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
278 if worktree.host_connection_id == connection_id {
279 worktree.share = Some(WorktreeShare {
280 guest_connection_ids: Default::default(),
281 active_replica_ids: Default::default(),
282 entries,
283 });
284 return Some(worktree.collaborator_user_ids.clone());
285 }
286 }
287 None
288 }
289
290 pub fn unshare_worktree(
291 &mut self,
292 worktree_id: u64,
293 acting_connection_id: ConnectionId,
294 ) -> tide::Result<WorktreeMetadata> {
295 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
296 worktree
297 } else {
298 return Err(anyhow!("no such worktree"))?;
299 };
300
301 if worktree.host_connection_id != acting_connection_id {
302 return Err(anyhow!("not your worktree"))?;
303 }
304
305 let connection_ids = worktree.connection_ids();
306
307 if let Some(_) = worktree.share.take() {
308 for connection_id in &connection_ids {
309 if let Some(connection) = self.connections.get_mut(connection_id) {
310 connection.worktrees.remove(&worktree_id);
311 }
312 }
313 Ok(WorktreeMetadata {
314 connection_ids,
315 collaborator_ids: worktree.collaborator_user_ids.clone(),
316 })
317 } else {
318 Err(anyhow!("worktree is not shared"))?
319 }
320 }
321
322 pub fn join_worktree(
323 &mut self,
324 connection_id: ConnectionId,
325 user_id: UserId,
326 worktree_id: u64,
327 ) -> tide::Result<JoinedWorktree> {
328 let connection = self
329 .connections
330 .get_mut(&connection_id)
331 .ok_or_else(|| anyhow!("no such connection"))?;
332 let worktree = self
333 .worktrees
334 .get_mut(&worktree_id)
335 .and_then(|worktree| {
336 if worktree.collaborator_user_ids.contains(&user_id) {
337 Some(worktree)
338 } else {
339 None
340 }
341 })
342 .ok_or_else(|| anyhow!("no such worktree"))?;
343
344 let share = worktree.share_mut()?;
345 connection.worktrees.insert(worktree_id);
346
347 let mut replica_id = 1;
348 while share.active_replica_ids.contains(&replica_id) {
349 replica_id += 1;
350 }
351 share.active_replica_ids.insert(replica_id);
352 share.guest_connection_ids.insert(connection_id, replica_id);
353 Ok(JoinedWorktree {
354 replica_id,
355 worktree,
356 })
357 }
358
359 pub fn leave_worktree(
360 &mut self,
361 connection_id: ConnectionId,
362 worktree_id: u64,
363 ) -> Option<WorktreeMetadata> {
364 let worktree = self.worktrees.get_mut(&worktree_id)?;
365 let share = worktree.share.as_mut()?;
366 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
367 share.active_replica_ids.remove(&replica_id);
368 Some(WorktreeMetadata {
369 connection_ids: worktree.connection_ids(),
370 collaborator_ids: worktree.collaborator_user_ids.clone(),
371 })
372 }
373
374 pub fn update_worktree(
375 &mut self,
376 connection_id: ConnectionId,
377 worktree_id: u64,
378 removed_entries: &[u64],
379 updated_entries: &[proto::Entry],
380 ) -> tide::Result<Vec<ConnectionId>> {
381 let worktree = self.write_worktree(worktree_id, connection_id)?;
382 let share = worktree.share_mut()?;
383 for entry_id in removed_entries {
384 share.entries.remove(&entry_id);
385 }
386 for entry in updated_entries {
387 share.entries.insert(entry.id, entry.clone());
388 }
389 Ok(worktree.connection_ids())
390 }
391
392 pub fn worktree_host_connection_id(
393 &self,
394 connection_id: ConnectionId,
395 worktree_id: u64,
396 ) -> tide::Result<ConnectionId> {
397 Ok(self
398 .read_worktree(worktree_id, connection_id)?
399 .host_connection_id)
400 }
401
402 pub fn worktree_guest_connection_ids(
403 &self,
404 connection_id: ConnectionId,
405 worktree_id: u64,
406 ) -> tide::Result<Vec<ConnectionId>> {
407 Ok(self
408 .read_worktree(worktree_id, connection_id)?
409 .share()?
410 .guest_connection_ids
411 .keys()
412 .copied()
413 .collect())
414 }
415
416 pub fn worktree_connection_ids(
417 &self,
418 connection_id: ConnectionId,
419 worktree_id: u64,
420 ) -> tide::Result<Vec<ConnectionId>> {
421 Ok(self
422 .read_worktree(worktree_id, connection_id)?
423 .connection_ids())
424 }
425
426 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
427 Some(self.channels.get(&channel_id)?.connection_ids())
428 }
429
430 fn read_worktree(
431 &self,
432 worktree_id: u64,
433 connection_id: ConnectionId,
434 ) -> tide::Result<&Worktree> {
435 let worktree = self
436 .worktrees
437 .get(&worktree_id)
438 .ok_or_else(|| anyhow!("worktree not found"))?;
439
440 if worktree.host_connection_id == connection_id
441 || worktree
442 .share()?
443 .guest_connection_ids
444 .contains_key(&connection_id)
445 {
446 Ok(worktree)
447 } else {
448 Err(anyhow!(
449 "{} is not a member of worktree {}",
450 connection_id,
451 worktree_id
452 ))?
453 }
454 }
455
456 fn write_worktree(
457 &mut self,
458 worktree_id: u64,
459 connection_id: ConnectionId,
460 ) -> tide::Result<&mut Worktree> {
461 let worktree = self
462 .worktrees
463 .get_mut(&worktree_id)
464 .ok_or_else(|| anyhow!("worktree not found"))?;
465
466 if worktree.host_connection_id == connection_id
467 || worktree.share.as_ref().map_or(false, |share| {
468 share.guest_connection_ids.contains_key(&connection_id)
469 })
470 {
471 Ok(worktree)
472 } else {
473 Err(anyhow!(
474 "{} is not a member of worktree {}",
475 connection_id,
476 worktree_id
477 ))?
478 }
479 }
480
481 #[cfg(test)]
482 fn check_invariants(&self) {
483 for (connection_id, connection) in &self.connections {
484 for worktree_id in &connection.worktrees {
485 let worktree = &self.worktrees.get(&worktree_id).unwrap();
486 if worktree.host_connection_id != *connection_id {
487 assert!(worktree
488 .share()
489 .unwrap()
490 .guest_connection_ids
491 .contains_key(connection_id));
492 }
493 }
494 for channel_id in &connection.channels {
495 let channel = self.channels.get(channel_id).unwrap();
496 assert!(channel.connection_ids.contains(connection_id));
497 }
498 assert!(self
499 .connections_by_user_id
500 .get(&connection.user_id)
501 .unwrap()
502 .contains(connection_id));
503 }
504
505 for (user_id, connection_ids) in &self.connections_by_user_id {
506 for connection_id in connection_ids {
507 assert_eq!(
508 self.connections.get(connection_id).unwrap().user_id,
509 *user_id
510 );
511 }
512 }
513
514 for (worktree_id, worktree) in &self.worktrees {
515 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
516 assert!(host_connection.worktrees.contains(worktree_id));
517
518 for collaborator_id in &worktree.collaborator_user_ids {
519 let visible_worktree_ids = self
520 .visible_worktrees_by_user_id
521 .get(collaborator_id)
522 .unwrap();
523 assert!(visible_worktree_ids.contains(worktree_id));
524 }
525
526 if let Some(share) = &worktree.share {
527 for guest_connection_id in share.guest_connection_ids.keys() {
528 let guest_connection = self.connections.get(guest_connection_id).unwrap();
529 assert!(guest_connection.worktrees.contains(worktree_id));
530 }
531 assert_eq!(
532 share.active_replica_ids.len(),
533 share.guest_connection_ids.len(),
534 );
535 assert_eq!(
536 share.active_replica_ids,
537 share
538 .guest_connection_ids
539 .values()
540 .copied()
541 .collect::<HashSet<_>>(),
542 );
543 }
544 }
545
546 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
547 for worktree_id in visible_worktree_ids {
548 let worktree = self.worktrees.get(worktree_id).unwrap();
549 assert!(worktree.collaborator_user_ids.contains(user_id));
550 }
551 }
552
553 for (channel_id, channel) in &self.channels {
554 for connection_id in &channel.connection_ids {
555 let connection = self.connections.get(connection_id).unwrap();
556 assert!(connection.channels.contains(channel_id));
557 }
558 }
559 }
560}
561
562impl Worktree {
563 pub fn connection_ids(&self) -> Vec<ConnectionId> {
564 if let Some(share) = &self.share {
565 share
566 .guest_connection_ids
567 .keys()
568 .copied()
569 .chain(Some(self.host_connection_id))
570 .collect()
571 } else {
572 vec![self.host_connection_id]
573 }
574 }
575
576 pub fn share(&self) -> tide::Result<&WorktreeShare> {
577 Ok(self
578 .share
579 .as_ref()
580 .ok_or_else(|| anyhow!("worktree is not shared"))?)
581 }
582
583 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
584 Ok(self
585 .share
586 .as_mut()
587 .ok_or_else(|| anyhow!("worktree is not shared"))?)
588 }
589}
590
591impl Channel {
592 fn connection_ids(&self) -> Vec<ConnectionId> {
593 self.connection_ids.iter().copied().collect()
594 }
595}