1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use rpc::{proto, ConnectionId};
4use std::collections::{hash_map, HashMap, HashSet};
5
6#[derive(Default)]
7pub struct Store {
8 connections: HashMap<ConnectionId, ConnectionState>,
9 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
10 worktrees: HashMap<u64, Worktree>,
11 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
12 channels: HashMap<ChannelId, Channel>,
13 next_worktree_id: u64,
14}
15
16struct ConnectionState {
17 user_id: UserId,
18 worktrees: HashSet<u64>,
19 channels: HashSet<ChannelId>,
20}
21
22pub struct Worktree {
23 pub host_connection_id: ConnectionId,
24 pub host_user_id: UserId,
25 pub authorized_user_ids: Vec<UserId>,
26 pub root_name: String,
27 pub share: Option<WorktreeShare>,
28}
29
30pub struct WorktreeShare {
31 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
32 pub active_replica_ids: HashSet<ReplicaId>,
33 pub entries: HashMap<u64, proto::Entry>,
34}
35
36#[derive(Default)]
37pub struct Channel {
38 pub connection_ids: HashSet<ConnectionId>,
39}
40
/// Identifier assigned to each guest of a shared worktree. Guests are numbered
/// from 1 upward by `Store::join_worktree`.
pub type ReplicaId = u16;
42
43#[derive(Default)]
44pub struct RemovedConnectionState {
45 pub hosted_worktrees: HashMap<u64, Worktree>,
46 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
47 pub contact_ids: HashSet<UserId>,
48}
49
50pub struct JoinedWorktree<'a> {
51 pub replica_id: ReplicaId,
52 pub worktree: &'a Worktree,
53}
54
55pub struct UnsharedWorktree {
56 pub connection_ids: Vec<ConnectionId>,
57 pub authorized_user_ids: Vec<UserId>,
58}
59
60pub struct LeftWorktree {
61 pub connection_ids: Vec<ConnectionId>,
62 pub authorized_user_ids: Vec<UserId>,
63}
64
65impl Store {
66 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
67 self.connections.insert(
68 connection_id,
69 ConnectionState {
70 user_id,
71 worktrees: Default::default(),
72 channels: Default::default(),
73 },
74 );
75 self.connections_by_user_id
76 .entry(user_id)
77 .or_default()
78 .insert(connection_id);
79 }
80
81 pub fn remove_connection(
82 &mut self,
83 connection_id: ConnectionId,
84 ) -> tide::Result<RemovedConnectionState> {
85 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
86 connection
87 } else {
88 return Err(anyhow!("no such connection"))?;
89 };
90
91 for channel_id in &connection.channels {
92 if let Some(channel) = self.channels.get_mut(&channel_id) {
93 channel.connection_ids.remove(&connection_id);
94 }
95 }
96
97 let user_connections = self
98 .connections_by_user_id
99 .get_mut(&connection.user_id)
100 .unwrap();
101 user_connections.remove(&connection_id);
102 if user_connections.is_empty() {
103 self.connections_by_user_id.remove(&connection.user_id);
104 }
105
106 let mut result = RemovedConnectionState::default();
107 for worktree_id in connection.worktrees.clone() {
108 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109 result
110 .contact_ids
111 .extend(worktree.authorized_user_ids.iter().copied());
112 result.hosted_worktrees.insert(worktree_id, worktree);
113 } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
114 result
115 .guest_worktree_ids
116 .insert(worktree_id, worktree.connection_ids);
117 result.contact_ids.extend(worktree.authorized_user_ids);
118 }
119 }
120
121 #[cfg(test)]
122 self.check_invariants();
123
124 Ok(result)
125 }
126
127 #[cfg(test)]
128 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129 self.channels.get(&id)
130 }
131
132 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133 if let Some(connection) = self.connections.get_mut(&connection_id) {
134 connection.channels.insert(channel_id);
135 self.channels
136 .entry(channel_id)
137 .or_default()
138 .connection_ids
139 .insert(connection_id);
140 }
141 }
142
143 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.remove(&channel_id);
146 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147 entry.get_mut().connection_ids.remove(&connection_id);
148 if entry.get_mut().connection_ids.is_empty() {
149 entry.remove();
150 }
151 }
152 }
153 }
154
155 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156 Ok(self
157 .connections
158 .get(&connection_id)
159 .ok_or_else(|| anyhow!("unknown connection"))?
160 .user_id)
161 }
162
163 pub fn connection_ids_for_user<'a>(
164 &'a self,
165 user_id: UserId,
166 ) -> impl 'a + Iterator<Item = ConnectionId> {
167 self.connections_by_user_id
168 .get(&user_id)
169 .into_iter()
170 .flatten()
171 .copied()
172 }
173
174 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
175 let mut contacts = HashMap::new();
176 for worktree_id in self
177 .visible_worktrees_by_user_id
178 .get(&user_id)
179 .unwrap_or(&HashSet::new())
180 {
181 let worktree = &self.worktrees[worktree_id];
182
183 let mut guests = HashSet::new();
184 if let Ok(share) = worktree.share() {
185 for guest_connection_id in share.guests.keys() {
186 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187 guests.insert(user_id.to_proto());
188 }
189 }
190 }
191
192 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193 contacts
194 .entry(host_user_id)
195 .or_insert_with(|| proto::Contact {
196 user_id: host_user_id.to_proto(),
197 worktrees: Vec::new(),
198 })
199 .worktrees
200 .push(proto::WorktreeMetadata {
201 id: *worktree_id,
202 root_name: worktree.root_name.clone(),
203 is_shared: worktree.share.is_some(),
204 guests: guests.into_iter().collect(),
205 });
206 }
207 }
208
209 contacts.into_values().collect()
210 }
211
212 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213 let worktree_id = self.next_worktree_id;
214 for authorized_user_id in &worktree.authorized_user_ids {
215 self.visible_worktrees_by_user_id
216 .entry(*authorized_user_id)
217 .or_default()
218 .insert(worktree_id);
219 }
220 self.next_worktree_id += 1;
221 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222 connection.worktrees.insert(worktree_id);
223 }
224 self.worktrees.insert(worktree_id, worktree);
225
226 #[cfg(test)]
227 self.check_invariants();
228
229 worktree_id
230 }
231
232 pub fn remove_worktree(
233 &mut self,
234 worktree_id: u64,
235 acting_connection_id: ConnectionId,
236 ) -> tide::Result<Worktree> {
237 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238 if e.get().host_connection_id != acting_connection_id {
239 Err(anyhow!("not your worktree"))?;
240 }
241 e.remove()
242 } else {
243 return Err(anyhow!("no such worktree"))?;
244 };
245
246 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247 connection.worktrees.remove(&worktree_id);
248 }
249
250 if let Some(share) = &worktree.share {
251 for connection_id in share.guests.keys() {
252 if let Some(connection) = self.connections.get_mut(connection_id) {
253 connection.worktrees.remove(&worktree_id);
254 }
255 }
256 }
257
258 for authorized_user_id in &worktree.authorized_user_ids {
259 if let Some(visible_worktrees) = self
260 .visible_worktrees_by_user_id
261 .get_mut(&authorized_user_id)
262 {
263 visible_worktrees.remove(&worktree_id);
264 }
265 }
266
267 #[cfg(test)]
268 self.check_invariants();
269
270 Ok(worktree)
271 }
272
273 pub fn share_worktree(
274 &mut self,
275 worktree_id: u64,
276 connection_id: ConnectionId,
277 entries: HashMap<u64, proto::Entry>,
278 ) -> Option<Vec<UserId>> {
279 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
280 if worktree.host_connection_id == connection_id {
281 worktree.share = Some(WorktreeShare {
282 guests: Default::default(),
283 active_replica_ids: Default::default(),
284 entries,
285 });
286 return Some(worktree.authorized_user_ids.clone());
287 }
288 }
289 None
290 }
291
292 pub fn unshare_worktree(
293 &mut self,
294 worktree_id: u64,
295 acting_connection_id: ConnectionId,
296 ) -> tide::Result<UnsharedWorktree> {
297 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
298 worktree
299 } else {
300 return Err(anyhow!("no such worktree"))?;
301 };
302
303 if worktree.host_connection_id != acting_connection_id {
304 return Err(anyhow!("not your worktree"))?;
305 }
306
307 let connection_ids = worktree.connection_ids();
308 let authorized_user_ids = worktree.authorized_user_ids.clone();
309 if let Some(share) = worktree.share.take() {
310 for connection_id in share.guests.into_keys() {
311 if let Some(connection) = self.connections.get_mut(&connection_id) {
312 connection.worktrees.remove(&worktree_id);
313 }
314 }
315
316 #[cfg(test)]
317 self.check_invariants();
318
319 Ok(UnsharedWorktree {
320 connection_ids,
321 authorized_user_ids,
322 })
323 } else {
324 Err(anyhow!("worktree is not shared"))?
325 }
326 }
327
328 pub fn join_worktree(
329 &mut self,
330 connection_id: ConnectionId,
331 user_id: UserId,
332 worktree_id: u64,
333 ) -> tide::Result<JoinedWorktree> {
334 let connection = self
335 .connections
336 .get_mut(&connection_id)
337 .ok_or_else(|| anyhow!("no such connection"))?;
338 let worktree = self
339 .worktrees
340 .get_mut(&worktree_id)
341 .and_then(|worktree| {
342 if worktree.authorized_user_ids.contains(&user_id) {
343 Some(worktree)
344 } else {
345 None
346 }
347 })
348 .ok_or_else(|| anyhow!("no such worktree"))?;
349
350 let share = worktree.share_mut()?;
351 connection.worktrees.insert(worktree_id);
352
353 let mut replica_id = 1;
354 while share.active_replica_ids.contains(&replica_id) {
355 replica_id += 1;
356 }
357 share.active_replica_ids.insert(replica_id);
358 share.guests.insert(connection_id, (replica_id, user_id));
359
360 #[cfg(test)]
361 self.check_invariants();
362
363 Ok(JoinedWorktree {
364 replica_id,
365 worktree: &self.worktrees[&worktree_id],
366 })
367 }
368
369 pub fn leave_worktree(
370 &mut self,
371 connection_id: ConnectionId,
372 worktree_id: u64,
373 ) -> Option<LeftWorktree> {
374 let worktree = self.worktrees.get_mut(&worktree_id)?;
375 let share = worktree.share.as_mut()?;
376 let (replica_id, _) = share.guests.remove(&connection_id)?;
377 share.active_replica_ids.remove(&replica_id);
378
379 if let Some(connection) = self.connections.get_mut(&connection_id) {
380 connection.worktrees.remove(&worktree_id);
381 }
382
383 let connection_ids = worktree.connection_ids();
384 let authorized_user_ids = worktree.authorized_user_ids.clone();
385
386 #[cfg(test)]
387 self.check_invariants();
388
389 Some(LeftWorktree {
390 connection_ids,
391 authorized_user_ids,
392 })
393 }
394
395 pub fn update_worktree(
396 &mut self,
397 connection_id: ConnectionId,
398 worktree_id: u64,
399 removed_entries: &[u64],
400 updated_entries: &[proto::Entry],
401 ) -> tide::Result<Vec<ConnectionId>> {
402 let worktree = self.write_worktree(worktree_id, connection_id)?;
403 let share = worktree.share_mut()?;
404 for entry_id in removed_entries {
405 share.entries.remove(&entry_id);
406 }
407 for entry in updated_entries {
408 share.entries.insert(entry.id, entry.clone());
409 }
410 Ok(worktree.connection_ids())
411 }
412
413 pub fn worktree_host_connection_id(
414 &self,
415 connection_id: ConnectionId,
416 worktree_id: u64,
417 ) -> tide::Result<ConnectionId> {
418 Ok(self
419 .read_worktree(worktree_id, connection_id)?
420 .host_connection_id)
421 }
422
423 pub fn worktree_guest_connection_ids(
424 &self,
425 connection_id: ConnectionId,
426 worktree_id: u64,
427 ) -> tide::Result<Vec<ConnectionId>> {
428 Ok(self
429 .read_worktree(worktree_id, connection_id)?
430 .share()?
431 .guests
432 .keys()
433 .copied()
434 .collect())
435 }
436
437 pub fn worktree_connection_ids(
438 &self,
439 connection_id: ConnectionId,
440 worktree_id: u64,
441 ) -> tide::Result<Vec<ConnectionId>> {
442 Ok(self
443 .read_worktree(worktree_id, connection_id)?
444 .connection_ids())
445 }
446
447 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
448 Some(self.channels.get(&channel_id)?.connection_ids())
449 }
450
451 fn read_worktree(
452 &self,
453 worktree_id: u64,
454 connection_id: ConnectionId,
455 ) -> tide::Result<&Worktree> {
456 let worktree = self
457 .worktrees
458 .get(&worktree_id)
459 .ok_or_else(|| anyhow!("worktree not found"))?;
460
461 if worktree.host_connection_id == connection_id
462 || worktree.share()?.guests.contains_key(&connection_id)
463 {
464 Ok(worktree)
465 } else {
466 Err(anyhow!(
467 "{} is not a member of worktree {}",
468 connection_id,
469 worktree_id
470 ))?
471 }
472 }
473
474 fn write_worktree(
475 &mut self,
476 worktree_id: u64,
477 connection_id: ConnectionId,
478 ) -> tide::Result<&mut Worktree> {
479 let worktree = self
480 .worktrees
481 .get_mut(&worktree_id)
482 .ok_or_else(|| anyhow!("worktree not found"))?;
483
484 if worktree.host_connection_id == connection_id
485 || worktree
486 .share
487 .as_ref()
488 .map_or(false, |share| share.guests.contains_key(&connection_id))
489 {
490 Ok(worktree)
491 } else {
492 Err(anyhow!(
493 "{} is not a member of worktree {}",
494 connection_id,
495 worktree_id
496 ))?
497 }
498 }
499
500 #[cfg(test)]
501 fn check_invariants(&self) {
502 for (connection_id, connection) in &self.connections {
503 for worktree_id in &connection.worktrees {
504 let worktree = &self.worktrees.get(&worktree_id).unwrap();
505 if worktree.host_connection_id != *connection_id {
506 assert!(worktree.share().unwrap().guests.contains_key(connection_id));
507 }
508 }
509 for channel_id in &connection.channels {
510 let channel = self.channels.get(channel_id).unwrap();
511 assert!(channel.connection_ids.contains(connection_id));
512 }
513 assert!(self
514 .connections_by_user_id
515 .get(&connection.user_id)
516 .unwrap()
517 .contains(connection_id));
518 }
519
520 for (user_id, connection_ids) in &self.connections_by_user_id {
521 for connection_id in connection_ids {
522 assert_eq!(
523 self.connections.get(connection_id).unwrap().user_id,
524 *user_id
525 );
526 }
527 }
528
529 for (worktree_id, worktree) in &self.worktrees {
530 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
531 assert!(host_connection.worktrees.contains(worktree_id));
532
533 for authorized_user_ids in &worktree.authorized_user_ids {
534 let visible_worktree_ids = self
535 .visible_worktrees_by_user_id
536 .get(authorized_user_ids)
537 .unwrap();
538 assert!(visible_worktree_ids.contains(worktree_id));
539 }
540
541 if let Some(share) = &worktree.share {
542 for guest_connection_id in share.guests.keys() {
543 let guest_connection = self.connections.get(guest_connection_id).unwrap();
544 assert!(guest_connection.worktrees.contains(worktree_id));
545 }
546 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
547 assert_eq!(
548 share.active_replica_ids,
549 share
550 .guests
551 .values()
552 .map(|(replica_id, _)| *replica_id)
553 .collect::<HashSet<_>>(),
554 );
555 }
556 }
557
558 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
559 for worktree_id in visible_worktree_ids {
560 let worktree = self.worktrees.get(worktree_id).unwrap();
561 assert!(worktree.authorized_user_ids.contains(user_id));
562 }
563 }
564
565 for (channel_id, channel) in &self.channels {
566 for connection_id in &channel.connection_ids {
567 let connection = self.connections.get(connection_id).unwrap();
568 assert!(connection.channels.contains(channel_id));
569 }
570 }
571 }
572}
573
574impl Worktree {
575 pub fn connection_ids(&self) -> Vec<ConnectionId> {
576 if let Some(share) = &self.share {
577 share
578 .guests
579 .keys()
580 .copied()
581 .chain(Some(self.host_connection_id))
582 .collect()
583 } else {
584 vec![self.host_connection_id]
585 }
586 }
587
588 pub fn share(&self) -> tide::Result<&WorktreeShare> {
589 Ok(self
590 .share
591 .as_ref()
592 .ok_or_else(|| anyhow!("worktree is not shared"))?)
593 }
594
595 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
596 Ok(self
597 .share
598 .as_mut()
599 .ok_or_else(|| anyhow!("worktree is not shared"))?)
600 }
601}
602
603impl Channel {
604 fn connection_ids(&self) -> Vec<ConnectionId> {
605 self.connection_ids.iter().copied().collect()
606 }
607}