1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use collections::{HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use std::collections::hash_map;
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 worktrees: HashMap<u64, Worktree>,
12 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_worktree_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 worktrees: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Worktree {
24 pub host_connection_id: ConnectionId,
25 pub host_user_id: UserId,
26 pub authorized_user_ids: Vec<UserId>,
27 pub root_name: String,
28 pub share: Option<WorktreeShare>,
29}
30
31pub struct WorktreeShare {
32 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
33 pub active_replica_ids: HashSet<ReplicaId>,
34 pub entries: HashMap<u64, proto::Entry>,
35}
36
37#[derive(Default)]
38pub struct Channel {
39 pub connection_ids: HashSet<ConnectionId>,
40}
41
/// Identifier handed to each guest that joins a shared worktree.
pub type ReplicaId = u16;
43
44#[derive(Default)]
45pub struct RemovedConnectionState {
46 pub hosted_worktrees: HashMap<u64, Worktree>,
47 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
48 pub contact_ids: HashSet<UserId>,
49}
50
51pub struct JoinedWorktree<'a> {
52 pub replica_id: ReplicaId,
53 pub worktree: &'a Worktree,
54}
55
56pub struct UnsharedWorktree {
57 pub connection_ids: Vec<ConnectionId>,
58 pub authorized_user_ids: Vec<UserId>,
59}
60
61pub struct LeftWorktree {
62 pub connection_ids: Vec<ConnectionId>,
63 pub authorized_user_ids: Vec<UserId>,
64}
65
66impl Store {
67 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
68 self.connections.insert(
69 connection_id,
70 ConnectionState {
71 user_id,
72 worktrees: Default::default(),
73 channels: Default::default(),
74 },
75 );
76 self.connections_by_user_id
77 .entry(user_id)
78 .or_default()
79 .insert(connection_id);
80 }
81
82 pub fn remove_connection(
83 &mut self,
84 connection_id: ConnectionId,
85 ) -> tide::Result<RemovedConnectionState> {
86 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
87 connection
88 } else {
89 return Err(anyhow!("no such connection"))?;
90 };
91
92 for channel_id in &connection.channels {
93 if let Some(channel) = self.channels.get_mut(&channel_id) {
94 channel.connection_ids.remove(&connection_id);
95 }
96 }
97
98 let user_connections = self
99 .connections_by_user_id
100 .get_mut(&connection.user_id)
101 .unwrap();
102 user_connections.remove(&connection_id);
103 if user_connections.is_empty() {
104 self.connections_by_user_id.remove(&connection.user_id);
105 }
106
107 let mut result = RemovedConnectionState::default();
108 for worktree_id in connection.worktrees.clone() {
109 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
110 result
111 .contact_ids
112 .extend(worktree.authorized_user_ids.iter().copied());
113 result.hosted_worktrees.insert(worktree_id, worktree);
114 } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
115 result
116 .guest_worktree_ids
117 .insert(worktree_id, worktree.connection_ids);
118 result.contact_ids.extend(worktree.authorized_user_ids);
119 }
120 }
121
122 #[cfg(test)]
123 self.check_invariants();
124
125 Ok(result)
126 }
127
128 #[cfg(test)]
129 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
130 self.channels.get(&id)
131 }
132
133 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
134 if let Some(connection) = self.connections.get_mut(&connection_id) {
135 connection.channels.insert(channel_id);
136 self.channels
137 .entry(channel_id)
138 .or_default()
139 .connection_ids
140 .insert(connection_id);
141 }
142 }
143
144 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
145 if let Some(connection) = self.connections.get_mut(&connection_id) {
146 connection.channels.remove(&channel_id);
147 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
148 entry.get_mut().connection_ids.remove(&connection_id);
149 if entry.get_mut().connection_ids.is_empty() {
150 entry.remove();
151 }
152 }
153 }
154 }
155
156 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
157 Ok(self
158 .connections
159 .get(&connection_id)
160 .ok_or_else(|| anyhow!("unknown connection"))?
161 .user_id)
162 }
163
164 pub fn connection_ids_for_user<'a>(
165 &'a self,
166 user_id: UserId,
167 ) -> impl 'a + Iterator<Item = ConnectionId> {
168 self.connections_by_user_id
169 .get(&user_id)
170 .into_iter()
171 .flatten()
172 .copied()
173 }
174
175 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
176 let mut contacts = HashMap::default();
177 for worktree_id in self
178 .visible_worktrees_by_user_id
179 .get(&user_id)
180 .unwrap_or(&HashSet::default())
181 {
182 let worktree = &self.worktrees[worktree_id];
183
184 let mut guests = HashSet::default();
185 if let Ok(share) = worktree.share() {
186 for guest_connection_id in share.guests.keys() {
187 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
188 guests.insert(user_id.to_proto());
189 }
190 }
191 }
192
193 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
194 contacts
195 .entry(host_user_id)
196 .or_insert_with(|| proto::Contact {
197 user_id: host_user_id.to_proto(),
198 worktrees: Vec::new(),
199 })
200 .worktrees
201 .push(proto::WorktreeMetadata {
202 id: *worktree_id,
203 root_name: worktree.root_name.clone(),
204 is_shared: worktree.share.is_some(),
205 guests: guests.into_iter().collect(),
206 });
207 }
208 }
209
210 contacts.into_values().collect()
211 }
212
213 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
214 let worktree_id = self.next_worktree_id;
215 for authorized_user_id in &worktree.authorized_user_ids {
216 self.visible_worktrees_by_user_id
217 .entry(*authorized_user_id)
218 .or_default()
219 .insert(worktree_id);
220 }
221 self.next_worktree_id += 1;
222 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
223 connection.worktrees.insert(worktree_id);
224 }
225 self.worktrees.insert(worktree_id, worktree);
226
227 #[cfg(test)]
228 self.check_invariants();
229
230 worktree_id
231 }
232
233 pub fn remove_worktree(
234 &mut self,
235 worktree_id: u64,
236 acting_connection_id: ConnectionId,
237 ) -> tide::Result<Worktree> {
238 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
239 if e.get().host_connection_id != acting_connection_id {
240 Err(anyhow!("not your worktree"))?;
241 }
242 e.remove()
243 } else {
244 return Err(anyhow!("no such worktree"))?;
245 };
246
247 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
248 connection.worktrees.remove(&worktree_id);
249 }
250
251 if let Some(share) = &worktree.share {
252 for connection_id in share.guests.keys() {
253 if let Some(connection) = self.connections.get_mut(connection_id) {
254 connection.worktrees.remove(&worktree_id);
255 }
256 }
257 }
258
259 for authorized_user_id in &worktree.authorized_user_ids {
260 if let Some(visible_worktrees) = self
261 .visible_worktrees_by_user_id
262 .get_mut(&authorized_user_id)
263 {
264 visible_worktrees.remove(&worktree_id);
265 }
266 }
267
268 #[cfg(test)]
269 self.check_invariants();
270
271 Ok(worktree)
272 }
273
274 pub fn share_worktree(
275 &mut self,
276 worktree_id: u64,
277 connection_id: ConnectionId,
278 entries: HashMap<u64, proto::Entry>,
279 ) -> Option<Vec<UserId>> {
280 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
281 if worktree.host_connection_id == connection_id {
282 worktree.share = Some(WorktreeShare {
283 guests: Default::default(),
284 active_replica_ids: Default::default(),
285 entries,
286 });
287 return Some(worktree.authorized_user_ids.clone());
288 }
289 }
290 None
291 }
292
293 pub fn unshare_worktree(
294 &mut self,
295 worktree_id: u64,
296 acting_connection_id: ConnectionId,
297 ) -> tide::Result<UnsharedWorktree> {
298 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
299 worktree
300 } else {
301 return Err(anyhow!("no such worktree"))?;
302 };
303
304 if worktree.host_connection_id != acting_connection_id {
305 return Err(anyhow!("not your worktree"))?;
306 }
307
308 let connection_ids = worktree.connection_ids();
309 let authorized_user_ids = worktree.authorized_user_ids.clone();
310 if let Some(share) = worktree.share.take() {
311 for connection_id in share.guests.into_keys() {
312 if let Some(connection) = self.connections.get_mut(&connection_id) {
313 connection.worktrees.remove(&worktree_id);
314 }
315 }
316
317 #[cfg(test)]
318 self.check_invariants();
319
320 Ok(UnsharedWorktree {
321 connection_ids,
322 authorized_user_ids,
323 })
324 } else {
325 Err(anyhow!("worktree is not shared"))?
326 }
327 }
328
329 pub fn join_worktree(
330 &mut self,
331 connection_id: ConnectionId,
332 user_id: UserId,
333 worktree_id: u64,
334 ) -> tide::Result<JoinedWorktree> {
335 let connection = self
336 .connections
337 .get_mut(&connection_id)
338 .ok_or_else(|| anyhow!("no such connection"))?;
339 let worktree = self
340 .worktrees
341 .get_mut(&worktree_id)
342 .and_then(|worktree| {
343 if worktree.authorized_user_ids.contains(&user_id) {
344 Some(worktree)
345 } else {
346 None
347 }
348 })
349 .ok_or_else(|| anyhow!("no such worktree"))?;
350
351 let share = worktree.share_mut()?;
352 connection.worktrees.insert(worktree_id);
353
354 let mut replica_id = 1;
355 while share.active_replica_ids.contains(&replica_id) {
356 replica_id += 1;
357 }
358 share.active_replica_ids.insert(replica_id);
359 share.guests.insert(connection_id, (replica_id, user_id));
360
361 #[cfg(test)]
362 self.check_invariants();
363
364 Ok(JoinedWorktree {
365 replica_id,
366 worktree: &self.worktrees[&worktree_id],
367 })
368 }
369
370 pub fn leave_worktree(
371 &mut self,
372 connection_id: ConnectionId,
373 worktree_id: u64,
374 ) -> Option<LeftWorktree> {
375 let worktree = self.worktrees.get_mut(&worktree_id)?;
376 let share = worktree.share.as_mut()?;
377 let (replica_id, _) = share.guests.remove(&connection_id)?;
378 share.active_replica_ids.remove(&replica_id);
379
380 if let Some(connection) = self.connections.get_mut(&connection_id) {
381 connection.worktrees.remove(&worktree_id);
382 }
383
384 let connection_ids = worktree.connection_ids();
385 let authorized_user_ids = worktree.authorized_user_ids.clone();
386
387 #[cfg(test)]
388 self.check_invariants();
389
390 Some(LeftWorktree {
391 connection_ids,
392 authorized_user_ids,
393 })
394 }
395
396 pub fn update_worktree(
397 &mut self,
398 connection_id: ConnectionId,
399 worktree_id: u64,
400 removed_entries: &[u64],
401 updated_entries: &[proto::Entry],
402 ) -> tide::Result<Vec<ConnectionId>> {
403 let worktree = self.write_worktree(worktree_id, connection_id)?;
404 let share = worktree.share_mut()?;
405 for entry_id in removed_entries {
406 share.entries.remove(&entry_id);
407 }
408 for entry in updated_entries {
409 share.entries.insert(entry.id, entry.clone());
410 }
411 Ok(worktree.connection_ids())
412 }
413
414 pub fn worktree_host_connection_id(
415 &self,
416 connection_id: ConnectionId,
417 worktree_id: u64,
418 ) -> tide::Result<ConnectionId> {
419 Ok(self
420 .read_worktree(worktree_id, connection_id)?
421 .host_connection_id)
422 }
423
424 pub fn worktree_guest_connection_ids(
425 &self,
426 connection_id: ConnectionId,
427 worktree_id: u64,
428 ) -> tide::Result<Vec<ConnectionId>> {
429 Ok(self
430 .read_worktree(worktree_id, connection_id)?
431 .share()?
432 .guests
433 .keys()
434 .copied()
435 .collect())
436 }
437
438 pub fn worktree_connection_ids(
439 &self,
440 connection_id: ConnectionId,
441 worktree_id: u64,
442 ) -> tide::Result<Vec<ConnectionId>> {
443 Ok(self
444 .read_worktree(worktree_id, connection_id)?
445 .connection_ids())
446 }
447
448 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
449 Some(self.channels.get(&channel_id)?.connection_ids())
450 }
451
452 fn read_worktree(
453 &self,
454 worktree_id: u64,
455 connection_id: ConnectionId,
456 ) -> tide::Result<&Worktree> {
457 let worktree = self
458 .worktrees
459 .get(&worktree_id)
460 .ok_or_else(|| anyhow!("worktree not found"))?;
461
462 if worktree.host_connection_id == connection_id
463 || worktree.share()?.guests.contains_key(&connection_id)
464 {
465 Ok(worktree)
466 } else {
467 Err(anyhow!(
468 "{} is not a member of worktree {}",
469 connection_id,
470 worktree_id
471 ))?
472 }
473 }
474
475 fn write_worktree(
476 &mut self,
477 worktree_id: u64,
478 connection_id: ConnectionId,
479 ) -> tide::Result<&mut Worktree> {
480 let worktree = self
481 .worktrees
482 .get_mut(&worktree_id)
483 .ok_or_else(|| anyhow!("worktree not found"))?;
484
485 if worktree.host_connection_id == connection_id
486 || worktree
487 .share
488 .as_ref()
489 .map_or(false, |share| share.guests.contains_key(&connection_id))
490 {
491 Ok(worktree)
492 } else {
493 Err(anyhow!(
494 "{} is not a member of worktree {}",
495 connection_id,
496 worktree_id
497 ))?
498 }
499 }
500
501 #[cfg(test)]
502 fn check_invariants(&self) {
503 for (connection_id, connection) in &self.connections {
504 for worktree_id in &connection.worktrees {
505 let worktree = &self.worktrees.get(&worktree_id).unwrap();
506 if worktree.host_connection_id != *connection_id {
507 assert!(worktree.share().unwrap().guests.contains_key(connection_id));
508 }
509 }
510 for channel_id in &connection.channels {
511 let channel = self.channels.get(channel_id).unwrap();
512 assert!(channel.connection_ids.contains(connection_id));
513 }
514 assert!(self
515 .connections_by_user_id
516 .get(&connection.user_id)
517 .unwrap()
518 .contains(connection_id));
519 }
520
521 for (user_id, connection_ids) in &self.connections_by_user_id {
522 for connection_id in connection_ids {
523 assert_eq!(
524 self.connections.get(connection_id).unwrap().user_id,
525 *user_id
526 );
527 }
528 }
529
530 for (worktree_id, worktree) in &self.worktrees {
531 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
532 assert!(host_connection.worktrees.contains(worktree_id));
533
534 for authorized_user_ids in &worktree.authorized_user_ids {
535 let visible_worktree_ids = self
536 .visible_worktrees_by_user_id
537 .get(authorized_user_ids)
538 .unwrap();
539 assert!(visible_worktree_ids.contains(worktree_id));
540 }
541
542 if let Some(share) = &worktree.share {
543 for guest_connection_id in share.guests.keys() {
544 let guest_connection = self.connections.get(guest_connection_id).unwrap();
545 assert!(guest_connection.worktrees.contains(worktree_id));
546 }
547 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
548 assert_eq!(
549 share.active_replica_ids,
550 share
551 .guests
552 .values()
553 .map(|(replica_id, _)| *replica_id)
554 .collect::<HashSet<_>>(),
555 );
556 }
557 }
558
559 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
560 for worktree_id in visible_worktree_ids {
561 let worktree = self.worktrees.get(worktree_id).unwrap();
562 assert!(worktree.authorized_user_ids.contains(user_id));
563 }
564 }
565
566 for (channel_id, channel) in &self.channels {
567 for connection_id in &channel.connection_ids {
568 let connection = self.connections.get(connection_id).unwrap();
569 assert!(connection.channels.contains(channel_id));
570 }
571 }
572 }
573}
574
575impl Worktree {
576 pub fn connection_ids(&self) -> Vec<ConnectionId> {
577 if let Some(share) = &self.share {
578 share
579 .guests
580 .keys()
581 .copied()
582 .chain(Some(self.host_connection_id))
583 .collect()
584 } else {
585 vec![self.host_connection_id]
586 }
587 }
588
589 pub fn share(&self) -> tide::Result<&WorktreeShare> {
590 Ok(self
591 .share
592 .as_ref()
593 .ok_or_else(|| anyhow!("worktree is not shared"))?)
594 }
595
596 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
597 Ok(self
598 .share
599 .as_mut()
600 .ok_or_else(|| anyhow!("worktree is not shared"))?)
601 }
602}
603
604impl Channel {
605 fn connection_ids(&self) -> Vec<ConnectionId> {
606 self.connection_ids.iter().copied().collect()
607 }
608}