1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use std::collections::{hash_map, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5
6#[derive(Default)]
7pub struct Store {
8 connections: HashMap<ConnectionId, ConnectionState>,
9 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
10 worktrees: HashMap<u64, Worktree>,
11 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
12 channels: HashMap<ChannelId, Channel>,
13 next_worktree_id: u64,
14}
15
16struct ConnectionState {
17 user_id: UserId,
18 worktrees: HashSet<u64>,
19 channels: HashSet<ChannelId>,
20}
21
22pub struct Worktree {
23 pub host_connection_id: ConnectionId,
24 pub collaborator_user_ids: Vec<UserId>,
25 pub root_name: String,
26 pub share: Option<WorktreeShare>,
27}
28
29pub struct WorktreeShare {
30 pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
31 pub active_replica_ids: HashSet<ReplicaId>,
32 pub entries: HashMap<u64, proto::Entry>,
33}
34
35#[derive(Default)]
36pub struct Channel {
37 pub connection_ids: HashSet<ConnectionId>,
38}
39
40pub type ReplicaId = u16;
41
42#[derive(Default)]
43pub struct RemovedConnectionState {
44 pub hosted_worktrees: HashMap<u64, Worktree>,
45 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
46 pub collaborator_ids: HashSet<UserId>,
47}
48
49pub struct JoinedWorktree<'a> {
50 pub replica_id: ReplicaId,
51 pub worktree: &'a Worktree,
52}
53
54pub struct UnsharedWorktree {
55 pub connection_ids: Vec<ConnectionId>,
56 pub collaborator_ids: Vec<UserId>,
57}
58
59pub struct LeftWorktree {
60 pub connection_ids: Vec<ConnectionId>,
61 pub collaborator_ids: Vec<UserId>,
62}
63
64impl Store {
65 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
66 self.connections.insert(
67 connection_id,
68 ConnectionState {
69 user_id,
70 worktrees: Default::default(),
71 channels: Default::default(),
72 },
73 );
74 self.connections_by_user_id
75 .entry(user_id)
76 .or_default()
77 .insert(connection_id);
78 }
79
80 pub fn remove_connection(
81 &mut self,
82 connection_id: ConnectionId,
83 ) -> tide::Result<RemovedConnectionState> {
84 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
85 connection
86 } else {
87 return Err(anyhow!("no such connection"))?;
88 };
89
90 for channel_id in &connection.channels {
91 if let Some(channel) = self.channels.get_mut(&channel_id) {
92 channel.connection_ids.remove(&connection_id);
93 }
94 }
95
96 let user_connections = self
97 .connections_by_user_id
98 .get_mut(&connection.user_id)
99 .unwrap();
100 user_connections.remove(&connection_id);
101 if user_connections.is_empty() {
102 self.connections_by_user_id.remove(&connection.user_id);
103 }
104
105 let mut result = RemovedConnectionState::default();
106 for worktree_id in connection.worktrees.clone() {
107 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
108 result
109 .collaborator_ids
110 .extend(worktree.collaborator_user_ids.iter().copied());
111 result.hosted_worktrees.insert(worktree_id, worktree);
112 } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
113 result
114 .guest_worktree_ids
115 .insert(worktree_id, worktree.connection_ids);
116 result.collaborator_ids.extend(worktree.collaborator_ids);
117 }
118 }
119
120 #[cfg(test)]
121 self.check_invariants();
122
123 Ok(result)
124 }
125
126 #[cfg(test)]
127 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
128 self.channels.get(&id)
129 }
130
131 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
132 if let Some(connection) = self.connections.get_mut(&connection_id) {
133 connection.channels.insert(channel_id);
134 self.channels
135 .entry(channel_id)
136 .or_default()
137 .connection_ids
138 .insert(connection_id);
139 }
140 }
141
142 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
143 if let Some(connection) = self.connections.get_mut(&connection_id) {
144 connection.channels.remove(&channel_id);
145 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
146 entry.get_mut().connection_ids.remove(&connection_id);
147 if entry.get_mut().connection_ids.is_empty() {
148 entry.remove();
149 }
150 }
151 }
152 }
153
154 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
155 Ok(self
156 .connections
157 .get(&connection_id)
158 .ok_or_else(|| anyhow!("unknown connection"))?
159 .user_id)
160 }
161
162 pub fn connection_ids_for_user<'a>(
163 &'a self,
164 user_id: UserId,
165 ) -> impl 'a + Iterator<Item = ConnectionId> {
166 self.connections_by_user_id
167 .get(&user_id)
168 .into_iter()
169 .flatten()
170 .copied()
171 }
172
173 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
174 let mut collaborators = HashMap::new();
175 for worktree_id in self
176 .visible_worktrees_by_user_id
177 .get(&user_id)
178 .unwrap_or(&HashSet::new())
179 {
180 let worktree = &self.worktrees[worktree_id];
181
182 let mut guests = HashSet::new();
183 if let Ok(share) = worktree.share() {
184 for guest_connection_id in share.guest_connection_ids.keys() {
185 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
186 guests.insert(user_id.to_proto());
187 }
188 }
189 }
190
191 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
192 collaborators
193 .entry(host_user_id)
194 .or_insert_with(|| proto::Collaborator {
195 user_id: host_user_id.to_proto(),
196 worktrees: Vec::new(),
197 })
198 .worktrees
199 .push(proto::WorktreeMetadata {
200 id: *worktree_id,
201 root_name: worktree.root_name.clone(),
202 is_shared: worktree.share.is_some(),
203 guests: guests.into_iter().collect(),
204 });
205 }
206 }
207
208 collaborators.into_values().collect()
209 }
210
211 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
212 let worktree_id = self.next_worktree_id;
213 for collaborator_user_id in &worktree.collaborator_user_ids {
214 self.visible_worktrees_by_user_id
215 .entry(*collaborator_user_id)
216 .or_default()
217 .insert(worktree_id);
218 }
219 self.next_worktree_id += 1;
220 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
221 connection.worktrees.insert(worktree_id);
222 }
223 self.worktrees.insert(worktree_id, worktree);
224
225 #[cfg(test)]
226 self.check_invariants();
227
228 worktree_id
229 }
230
231 pub fn remove_worktree(
232 &mut self,
233 worktree_id: u64,
234 acting_connection_id: ConnectionId,
235 ) -> tide::Result<Worktree> {
236 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
237 if e.get().host_connection_id != acting_connection_id {
238 Err(anyhow!("not your worktree"))?;
239 }
240 e.remove()
241 } else {
242 return Err(anyhow!("no such worktree"))?;
243 };
244
245 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
246 connection.worktrees.remove(&worktree_id);
247 }
248
249 if let Some(share) = &worktree.share {
250 for connection_id in share.guest_connection_ids.keys() {
251 if let Some(connection) = self.connections.get_mut(connection_id) {
252 connection.worktrees.remove(&worktree_id);
253 }
254 }
255 }
256
257 for collaborator_user_id in &worktree.collaborator_user_ids {
258 if let Some(visible_worktrees) = self
259 .visible_worktrees_by_user_id
260 .get_mut(&collaborator_user_id)
261 {
262 visible_worktrees.remove(&worktree_id);
263 }
264 }
265
266 #[cfg(test)]
267 self.check_invariants();
268
269 Ok(worktree)
270 }
271
272 pub fn share_worktree(
273 &mut self,
274 worktree_id: u64,
275 connection_id: ConnectionId,
276 entries: HashMap<u64, proto::Entry>,
277 ) -> Option<Vec<UserId>> {
278 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
279 if worktree.host_connection_id == connection_id {
280 worktree.share = Some(WorktreeShare {
281 guest_connection_ids: Default::default(),
282 active_replica_ids: Default::default(),
283 entries,
284 });
285 return Some(worktree.collaborator_user_ids.clone());
286 }
287 }
288 None
289 }
290
291 pub fn unshare_worktree(
292 &mut self,
293 worktree_id: u64,
294 acting_connection_id: ConnectionId,
295 ) -> tide::Result<UnsharedWorktree> {
296 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
297 worktree
298 } else {
299 return Err(anyhow!("no such worktree"))?;
300 };
301
302 if worktree.host_connection_id != acting_connection_id {
303 return Err(anyhow!("not your worktree"))?;
304 }
305
306 let connection_ids = worktree.connection_ids();
307 let collaborator_ids = worktree.collaborator_user_ids.clone();
308 if let Some(share) = worktree.share.take() {
309 for connection_id in share.guest_connection_ids.into_keys() {
310 if let Some(connection) = self.connections.get_mut(&connection_id) {
311 connection.worktrees.remove(&worktree_id);
312 }
313 }
314
315 #[cfg(test)]
316 self.check_invariants();
317
318 Ok(UnsharedWorktree {
319 connection_ids,
320 collaborator_ids,
321 })
322 } else {
323 Err(anyhow!("worktree is not shared"))?
324 }
325 }
326
327 pub fn join_worktree(
328 &mut self,
329 connection_id: ConnectionId,
330 user_id: UserId,
331 worktree_id: u64,
332 ) -> tide::Result<JoinedWorktree> {
333 let connection = self
334 .connections
335 .get_mut(&connection_id)
336 .ok_or_else(|| anyhow!("no such connection"))?;
337 let worktree = self
338 .worktrees
339 .get_mut(&worktree_id)
340 .and_then(|worktree| {
341 if worktree.collaborator_user_ids.contains(&user_id) {
342 Some(worktree)
343 } else {
344 None
345 }
346 })
347 .ok_or_else(|| anyhow!("no such worktree"))?;
348
349 let share = worktree.share_mut()?;
350 connection.worktrees.insert(worktree_id);
351
352 let mut replica_id = 1;
353 while share.active_replica_ids.contains(&replica_id) {
354 replica_id += 1;
355 }
356 share.active_replica_ids.insert(replica_id);
357 share.guest_connection_ids.insert(connection_id, replica_id);
358
359 #[cfg(test)]
360 self.check_invariants();
361
362 Ok(JoinedWorktree {
363 replica_id,
364 worktree: &self.worktrees[&worktree_id],
365 })
366 }
367
368 pub fn leave_worktree(
369 &mut self,
370 connection_id: ConnectionId,
371 worktree_id: u64,
372 ) -> Option<LeftWorktree> {
373 let worktree = self.worktrees.get_mut(&worktree_id)?;
374 let share = worktree.share.as_mut()?;
375 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
376 share.active_replica_ids.remove(&replica_id);
377
378 if let Some(connection) = self.connections.get_mut(&connection_id) {
379 connection.worktrees.remove(&worktree_id);
380 }
381
382 let connection_ids = worktree.connection_ids();
383 let collaborator_ids = worktree.collaborator_user_ids.clone();
384
385 #[cfg(test)]
386 self.check_invariants();
387
388 Some(LeftWorktree {
389 connection_ids,
390 collaborator_ids,
391 })
392 }
393
394 pub fn update_worktree(
395 &mut self,
396 connection_id: ConnectionId,
397 worktree_id: u64,
398 removed_entries: &[u64],
399 updated_entries: &[proto::Entry],
400 ) -> tide::Result<Vec<ConnectionId>> {
401 let worktree = self.write_worktree(worktree_id, connection_id)?;
402 let share = worktree.share_mut()?;
403 for entry_id in removed_entries {
404 share.entries.remove(&entry_id);
405 }
406 for entry in updated_entries {
407 share.entries.insert(entry.id, entry.clone());
408 }
409 Ok(worktree.connection_ids())
410 }
411
412 pub fn worktree_host_connection_id(
413 &self,
414 connection_id: ConnectionId,
415 worktree_id: u64,
416 ) -> tide::Result<ConnectionId> {
417 Ok(self
418 .read_worktree(worktree_id, connection_id)?
419 .host_connection_id)
420 }
421
422 pub fn worktree_guest_connection_ids(
423 &self,
424 connection_id: ConnectionId,
425 worktree_id: u64,
426 ) -> tide::Result<Vec<ConnectionId>> {
427 Ok(self
428 .read_worktree(worktree_id, connection_id)?
429 .share()?
430 .guest_connection_ids
431 .keys()
432 .copied()
433 .collect())
434 }
435
436 pub fn worktree_connection_ids(
437 &self,
438 connection_id: ConnectionId,
439 worktree_id: u64,
440 ) -> tide::Result<Vec<ConnectionId>> {
441 Ok(self
442 .read_worktree(worktree_id, connection_id)?
443 .connection_ids())
444 }
445
446 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447 Some(self.channels.get(&channel_id)?.connection_ids())
448 }
449
450 fn read_worktree(
451 &self,
452 worktree_id: u64,
453 connection_id: ConnectionId,
454 ) -> tide::Result<&Worktree> {
455 let worktree = self
456 .worktrees
457 .get(&worktree_id)
458 .ok_or_else(|| anyhow!("worktree not found"))?;
459
460 if worktree.host_connection_id == connection_id
461 || worktree
462 .share()?
463 .guest_connection_ids
464 .contains_key(&connection_id)
465 {
466 Ok(worktree)
467 } else {
468 Err(anyhow!(
469 "{} is not a member of worktree {}",
470 connection_id,
471 worktree_id
472 ))?
473 }
474 }
475
476 fn write_worktree(
477 &mut self,
478 worktree_id: u64,
479 connection_id: ConnectionId,
480 ) -> tide::Result<&mut Worktree> {
481 let worktree = self
482 .worktrees
483 .get_mut(&worktree_id)
484 .ok_or_else(|| anyhow!("worktree not found"))?;
485
486 if worktree.host_connection_id == connection_id
487 || worktree.share.as_ref().map_or(false, |share| {
488 share.guest_connection_ids.contains_key(&connection_id)
489 })
490 {
491 Ok(worktree)
492 } else {
493 Err(anyhow!(
494 "{} is not a member of worktree {}",
495 connection_id,
496 worktree_id
497 ))?
498 }
499 }
500
501 #[cfg(test)]
502 fn check_invariants(&self) {
503 for (connection_id, connection) in &self.connections {
504 for worktree_id in &connection.worktrees {
505 let worktree = &self.worktrees.get(&worktree_id).unwrap();
506 if worktree.host_connection_id != *connection_id {
507 assert!(worktree
508 .share()
509 .unwrap()
510 .guest_connection_ids
511 .contains_key(connection_id));
512 }
513 }
514 for channel_id in &connection.channels {
515 let channel = self.channels.get(channel_id).unwrap();
516 assert!(channel.connection_ids.contains(connection_id));
517 }
518 assert!(self
519 .connections_by_user_id
520 .get(&connection.user_id)
521 .unwrap()
522 .contains(connection_id));
523 }
524
525 for (user_id, connection_ids) in &self.connections_by_user_id {
526 for connection_id in connection_ids {
527 assert_eq!(
528 self.connections.get(connection_id).unwrap().user_id,
529 *user_id
530 );
531 }
532 }
533
534 for (worktree_id, worktree) in &self.worktrees {
535 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
536 assert!(host_connection.worktrees.contains(worktree_id));
537
538 for collaborator_id in &worktree.collaborator_user_ids {
539 let visible_worktree_ids = self
540 .visible_worktrees_by_user_id
541 .get(collaborator_id)
542 .unwrap();
543 assert!(visible_worktree_ids.contains(worktree_id));
544 }
545
546 if let Some(share) = &worktree.share {
547 for guest_connection_id in share.guest_connection_ids.keys() {
548 let guest_connection = self.connections.get(guest_connection_id).unwrap();
549 assert!(guest_connection.worktrees.contains(worktree_id));
550 }
551 assert_eq!(
552 share.active_replica_ids.len(),
553 share.guest_connection_ids.len(),
554 );
555 assert_eq!(
556 share.active_replica_ids,
557 share
558 .guest_connection_ids
559 .values()
560 .copied()
561 .collect::<HashSet<_>>(),
562 );
563 }
564 }
565
566 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
567 for worktree_id in visible_worktree_ids {
568 let worktree = self.worktrees.get(worktree_id).unwrap();
569 assert!(worktree.collaborator_user_ids.contains(user_id));
570 }
571 }
572
573 for (channel_id, channel) in &self.channels {
574 for connection_id in &channel.connection_ids {
575 let connection = self.connections.get(connection_id).unwrap();
576 assert!(connection.channels.contains(channel_id));
577 }
578 }
579 }
580}
581
582impl Worktree {
583 pub fn connection_ids(&self) -> Vec<ConnectionId> {
584 if let Some(share) = &self.share {
585 share
586 .guest_connection_ids
587 .keys()
588 .copied()
589 .chain(Some(self.host_connection_id))
590 .collect()
591 } else {
592 vec![self.host_connection_id]
593 }
594 }
595
596 pub fn share(&self) -> tide::Result<&WorktreeShare> {
597 Ok(self
598 .share
599 .as_ref()
600 .ok_or_else(|| anyhow!("worktree is not shared"))?)
601 }
602
603 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
604 Ok(self
605 .share
606 .as_mut()
607 .ok_or_else(|| anyhow!("worktree is not shared"))?)
608 }
609}
610
611impl Channel {
612 fn connection_ids(&self) -> Vec<ConnectionId> {
613 self.connection_ids.iter().copied().collect()
614 }
615}