1use crate::db::{ChannelId, UserId};
2use anyhow::anyhow;
3use rpc::{proto, ConnectionId};
4use std::collections::{hash_map, HashMap, HashSet};
5
6#[derive(Default)]
7pub struct Store {
8 connections: HashMap<ConnectionId, ConnectionState>,
9 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
10 worktrees: HashMap<u64, Worktree>,
11 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
12 channels: HashMap<ChannelId, Channel>,
13 next_worktree_id: u64,
14}
15
16struct ConnectionState {
17 user_id: UserId,
18 worktrees: HashSet<u64>,
19 channels: HashSet<ChannelId>,
20}
21
22pub struct Worktree {
23 pub host_connection_id: ConnectionId,
24 pub host_user_id: UserId,
25 pub contact_user_ids: Vec<UserId>,
26 pub root_name: String,
27 pub share: Option<WorktreeShare>,
28}
29
30pub struct WorktreeShare {
31 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
32 pub active_replica_ids: HashSet<ReplicaId>,
33 pub entries: HashMap<u64, proto::Entry>,
34}
35
36#[derive(Default)]
37pub struct Channel {
38 pub connection_ids: HashSet<ConnectionId>,
39}
40
/// Identifier assigned to each participant replica of a shared worktree.
pub type ReplicaId = u16;
42
43#[derive(Default)]
44pub struct RemovedConnectionState {
45 pub hosted_worktrees: HashMap<u64, Worktree>,
46 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
47 pub contact_ids: HashSet<UserId>,
48}
49
50pub struct JoinedWorktree<'a> {
51 pub replica_id: ReplicaId,
52 pub worktree: &'a Worktree,
53}
54
55pub struct UnsharedWorktree {
56 pub connection_ids: Vec<ConnectionId>,
57 pub contact_ids: Vec<UserId>,
58}
59
60pub struct LeftWorktree {
61 pub connection_ids: Vec<ConnectionId>,
62 pub contact_ids: Vec<UserId>,
63}
64
65impl Store {
66 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
67 self.connections.insert(
68 connection_id,
69 ConnectionState {
70 user_id,
71 worktrees: Default::default(),
72 channels: Default::default(),
73 },
74 );
75 self.connections_by_user_id
76 .entry(user_id)
77 .or_default()
78 .insert(connection_id);
79 }
80
81 pub fn remove_connection(
82 &mut self,
83 connection_id: ConnectionId,
84 ) -> tide::Result<RemovedConnectionState> {
85 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
86 connection
87 } else {
88 return Err(anyhow!("no such connection"))?;
89 };
90
91 for channel_id in &connection.channels {
92 if let Some(channel) = self.channels.get_mut(&channel_id) {
93 channel.connection_ids.remove(&connection_id);
94 }
95 }
96
97 let user_connections = self
98 .connections_by_user_id
99 .get_mut(&connection.user_id)
100 .unwrap();
101 user_connections.remove(&connection_id);
102 if user_connections.is_empty() {
103 self.connections_by_user_id.remove(&connection.user_id);
104 }
105
106 let mut result = RemovedConnectionState::default();
107 for worktree_id in connection.worktrees.clone() {
108 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
109 result
110 .contact_ids
111 .extend(worktree.contact_user_ids.iter().copied());
112 result.hosted_worktrees.insert(worktree_id, worktree);
113 } else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
114 result
115 .guest_worktree_ids
116 .insert(worktree_id, worktree.connection_ids);
117 result.contact_ids.extend(worktree.contact_ids);
118 }
119 }
120
121 #[cfg(test)]
122 self.check_invariants();
123
124 Ok(result)
125 }
126
127 #[cfg(test)]
128 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
129 self.channels.get(&id)
130 }
131
132 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
133 if let Some(connection) = self.connections.get_mut(&connection_id) {
134 connection.channels.insert(channel_id);
135 self.channels
136 .entry(channel_id)
137 .or_default()
138 .connection_ids
139 .insert(connection_id);
140 }
141 }
142
143 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
144 if let Some(connection) = self.connections.get_mut(&connection_id) {
145 connection.channels.remove(&channel_id);
146 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
147 entry.get_mut().connection_ids.remove(&connection_id);
148 if entry.get_mut().connection_ids.is_empty() {
149 entry.remove();
150 }
151 }
152 }
153 }
154
155 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
156 Ok(self
157 .connections
158 .get(&connection_id)
159 .ok_or_else(|| anyhow!("unknown connection"))?
160 .user_id)
161 }
162
163 pub fn connection_ids_for_user<'a>(
164 &'a self,
165 user_id: UserId,
166 ) -> impl 'a + Iterator<Item = ConnectionId> {
167 self.connections_by_user_id
168 .get(&user_id)
169 .into_iter()
170 .flatten()
171 .copied()
172 }
173
174 pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
175 let mut contacts = HashMap::new();
176 for worktree_id in self
177 .visible_worktrees_by_user_id
178 .get(&user_id)
179 .unwrap_or(&HashSet::new())
180 {
181 let worktree = &self.worktrees[worktree_id];
182
183 let mut guests = HashSet::new();
184 if let Ok(share) = worktree.share() {
185 for guest_connection_id in share.guests.keys() {
186 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
187 guests.insert(user_id.to_proto());
188 }
189 }
190 }
191
192 if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
193 contacts
194 .entry(host_user_id)
195 .or_insert_with(|| proto::Contact {
196 user_id: host_user_id.to_proto(),
197 worktrees: Vec::new(),
198 })
199 .worktrees
200 .push(proto::WorktreeMetadata {
201 id: *worktree_id,
202 root_name: worktree.root_name.clone(),
203 is_shared: worktree.share.is_some(),
204 guests: guests.into_iter().collect(),
205 });
206 }
207 }
208
209 contacts.into_values().collect()
210 }
211
212 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
213 let worktree_id = self.next_worktree_id;
214 for contact_user_id in &worktree.contact_user_ids {
215 self.visible_worktrees_by_user_id
216 .entry(*contact_user_id)
217 .or_default()
218 .insert(worktree_id);
219 }
220 self.next_worktree_id += 1;
221 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
222 connection.worktrees.insert(worktree_id);
223 }
224 self.worktrees.insert(worktree_id, worktree);
225
226 #[cfg(test)]
227 self.check_invariants();
228
229 worktree_id
230 }
231
232 pub fn remove_worktree(
233 &mut self,
234 worktree_id: u64,
235 acting_connection_id: ConnectionId,
236 ) -> tide::Result<Worktree> {
237 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
238 if e.get().host_connection_id != acting_connection_id {
239 Err(anyhow!("not your worktree"))?;
240 }
241 e.remove()
242 } else {
243 return Err(anyhow!("no such worktree"))?;
244 };
245
246 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
247 connection.worktrees.remove(&worktree_id);
248 }
249
250 if let Some(share) = &worktree.share {
251 for connection_id in share.guests.keys() {
252 if let Some(connection) = self.connections.get_mut(connection_id) {
253 connection.worktrees.remove(&worktree_id);
254 }
255 }
256 }
257
258 for contact_user_id in &worktree.contact_user_ids {
259 if let Some(visible_worktrees) =
260 self.visible_worktrees_by_user_id.get_mut(&contact_user_id)
261 {
262 visible_worktrees.remove(&worktree_id);
263 }
264 }
265
266 #[cfg(test)]
267 self.check_invariants();
268
269 Ok(worktree)
270 }
271
272 pub fn share_worktree(
273 &mut self,
274 worktree_id: u64,
275 connection_id: ConnectionId,
276 entries: HashMap<u64, proto::Entry>,
277 ) -> Option<Vec<UserId>> {
278 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
279 if worktree.host_connection_id == connection_id {
280 worktree.share = Some(WorktreeShare {
281 guests: Default::default(),
282 active_replica_ids: Default::default(),
283 entries,
284 });
285 return Some(worktree.contact_user_ids.clone());
286 }
287 }
288 None
289 }
290
291 pub fn unshare_worktree(
292 &mut self,
293 worktree_id: u64,
294 acting_connection_id: ConnectionId,
295 ) -> tide::Result<UnsharedWorktree> {
296 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
297 worktree
298 } else {
299 return Err(anyhow!("no such worktree"))?;
300 };
301
302 if worktree.host_connection_id != acting_connection_id {
303 return Err(anyhow!("not your worktree"))?;
304 }
305
306 let connection_ids = worktree.connection_ids();
307 let contact_ids = worktree.contact_user_ids.clone();
308 if let Some(share) = worktree.share.take() {
309 for connection_id in share.guests.into_keys() {
310 if let Some(connection) = self.connections.get_mut(&connection_id) {
311 connection.worktrees.remove(&worktree_id);
312 }
313 }
314
315 #[cfg(test)]
316 self.check_invariants();
317
318 Ok(UnsharedWorktree {
319 connection_ids,
320 contact_ids,
321 })
322 } else {
323 Err(anyhow!("worktree is not shared"))?
324 }
325 }
326
327 pub fn join_worktree(
328 &mut self,
329 connection_id: ConnectionId,
330 user_id: UserId,
331 worktree_id: u64,
332 ) -> tide::Result<JoinedWorktree> {
333 let connection = self
334 .connections
335 .get_mut(&connection_id)
336 .ok_or_else(|| anyhow!("no such connection"))?;
337 let worktree = self
338 .worktrees
339 .get_mut(&worktree_id)
340 .and_then(|worktree| {
341 if worktree.contact_user_ids.contains(&user_id) {
342 Some(worktree)
343 } else {
344 None
345 }
346 })
347 .ok_or_else(|| anyhow!("no such worktree"))?;
348
349 let share = worktree.share_mut()?;
350 connection.worktrees.insert(worktree_id);
351
352 let mut replica_id = 1;
353 while share.active_replica_ids.contains(&replica_id) {
354 replica_id += 1;
355 }
356 share.active_replica_ids.insert(replica_id);
357 share.guests.insert(connection_id, (replica_id, user_id));
358
359 #[cfg(test)]
360 self.check_invariants();
361
362 Ok(JoinedWorktree {
363 replica_id,
364 worktree: &self.worktrees[&worktree_id],
365 })
366 }
367
368 pub fn leave_worktree(
369 &mut self,
370 connection_id: ConnectionId,
371 worktree_id: u64,
372 ) -> Option<LeftWorktree> {
373 let worktree = self.worktrees.get_mut(&worktree_id)?;
374 let share = worktree.share.as_mut()?;
375 let (replica_id, _) = share.guests.remove(&connection_id)?;
376 share.active_replica_ids.remove(&replica_id);
377
378 if let Some(connection) = self.connections.get_mut(&connection_id) {
379 connection.worktrees.remove(&worktree_id);
380 }
381
382 let connection_ids = worktree.connection_ids();
383 let contact_ids = worktree.contact_user_ids.clone();
384
385 #[cfg(test)]
386 self.check_invariants();
387
388 Some(LeftWorktree {
389 connection_ids,
390 contact_ids,
391 })
392 }
393
394 pub fn update_worktree(
395 &mut self,
396 connection_id: ConnectionId,
397 worktree_id: u64,
398 removed_entries: &[u64],
399 updated_entries: &[proto::Entry],
400 ) -> tide::Result<Vec<ConnectionId>> {
401 let worktree = self.write_worktree(worktree_id, connection_id)?;
402 let share = worktree.share_mut()?;
403 for entry_id in removed_entries {
404 share.entries.remove(&entry_id);
405 }
406 for entry in updated_entries {
407 share.entries.insert(entry.id, entry.clone());
408 }
409 Ok(worktree.connection_ids())
410 }
411
412 pub fn worktree_host_connection_id(
413 &self,
414 connection_id: ConnectionId,
415 worktree_id: u64,
416 ) -> tide::Result<ConnectionId> {
417 Ok(self
418 .read_worktree(worktree_id, connection_id)?
419 .host_connection_id)
420 }
421
422 pub fn worktree_guest_connection_ids(
423 &self,
424 connection_id: ConnectionId,
425 worktree_id: u64,
426 ) -> tide::Result<Vec<ConnectionId>> {
427 Ok(self
428 .read_worktree(worktree_id, connection_id)?
429 .share()?
430 .guests
431 .keys()
432 .copied()
433 .collect())
434 }
435
436 pub fn worktree_connection_ids(
437 &self,
438 connection_id: ConnectionId,
439 worktree_id: u64,
440 ) -> tide::Result<Vec<ConnectionId>> {
441 Ok(self
442 .read_worktree(worktree_id, connection_id)?
443 .connection_ids())
444 }
445
446 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
447 Some(self.channels.get(&channel_id)?.connection_ids())
448 }
449
450 fn read_worktree(
451 &self,
452 worktree_id: u64,
453 connection_id: ConnectionId,
454 ) -> tide::Result<&Worktree> {
455 let worktree = self
456 .worktrees
457 .get(&worktree_id)
458 .ok_or_else(|| anyhow!("worktree not found"))?;
459
460 if worktree.host_connection_id == connection_id
461 || worktree.share()?.guests.contains_key(&connection_id)
462 {
463 Ok(worktree)
464 } else {
465 Err(anyhow!(
466 "{} is not a member of worktree {}",
467 connection_id,
468 worktree_id
469 ))?
470 }
471 }
472
473 fn write_worktree(
474 &mut self,
475 worktree_id: u64,
476 connection_id: ConnectionId,
477 ) -> tide::Result<&mut Worktree> {
478 let worktree = self
479 .worktrees
480 .get_mut(&worktree_id)
481 .ok_or_else(|| anyhow!("worktree not found"))?;
482
483 if worktree.host_connection_id == connection_id
484 || worktree
485 .share
486 .as_ref()
487 .map_or(false, |share| share.guests.contains_key(&connection_id))
488 {
489 Ok(worktree)
490 } else {
491 Err(anyhow!(
492 "{} is not a member of worktree {}",
493 connection_id,
494 worktree_id
495 ))?
496 }
497 }
498
499 #[cfg(test)]
500 fn check_invariants(&self) {
501 for (connection_id, connection) in &self.connections {
502 for worktree_id in &connection.worktrees {
503 let worktree = &self.worktrees.get(&worktree_id).unwrap();
504 if worktree.host_connection_id != *connection_id {
505 assert!(worktree.share().unwrap().guests.contains_key(connection_id));
506 }
507 }
508 for channel_id in &connection.channels {
509 let channel = self.channels.get(channel_id).unwrap();
510 assert!(channel.connection_ids.contains(connection_id));
511 }
512 assert!(self
513 .connections_by_user_id
514 .get(&connection.user_id)
515 .unwrap()
516 .contains(connection_id));
517 }
518
519 for (user_id, connection_ids) in &self.connections_by_user_id {
520 for connection_id in connection_ids {
521 assert_eq!(
522 self.connections.get(connection_id).unwrap().user_id,
523 *user_id
524 );
525 }
526 }
527
528 for (worktree_id, worktree) in &self.worktrees {
529 let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
530 assert!(host_connection.worktrees.contains(worktree_id));
531
532 for contact_id in &worktree.contact_user_ids {
533 let visible_worktree_ids =
534 self.visible_worktrees_by_user_id.get(contact_id).unwrap();
535 assert!(visible_worktree_ids.contains(worktree_id));
536 }
537
538 if let Some(share) = &worktree.share {
539 for guest_connection_id in share.guests.keys() {
540 let guest_connection = self.connections.get(guest_connection_id).unwrap();
541 assert!(guest_connection.worktrees.contains(worktree_id));
542 }
543 assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
544 assert_eq!(
545 share.active_replica_ids,
546 share
547 .guests
548 .values()
549 .map(|(replica_id, _)| *replica_id)
550 .collect::<HashSet<_>>(),
551 );
552 }
553 }
554
555 for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
556 for worktree_id in visible_worktree_ids {
557 let worktree = self.worktrees.get(worktree_id).unwrap();
558 assert!(worktree.contact_user_ids.contains(user_id));
559 }
560 }
561
562 for (channel_id, channel) in &self.channels {
563 for connection_id in &channel.connection_ids {
564 let connection = self.connections.get(connection_id).unwrap();
565 assert!(connection.channels.contains(channel_id));
566 }
567 }
568 }
569}
570
571impl Worktree {
572 pub fn connection_ids(&self) -> Vec<ConnectionId> {
573 if let Some(share) = &self.share {
574 share
575 .guests
576 .keys()
577 .copied()
578 .chain(Some(self.host_connection_id))
579 .collect()
580 } else {
581 vec![self.host_connection_id]
582 }
583 }
584
585 pub fn share(&self) -> tide::Result<&WorktreeShare> {
586 Ok(self
587 .share
588 .as_ref()
589 .ok_or_else(|| anyhow!("worktree is not shared"))?)
590 }
591
592 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
593 Ok(self
594 .share
595 .as_mut()
596 .ok_or_else(|| anyhow!("worktree is not shared"))?)
597 }
598}
599
600impl Channel {
601 fn connection_ids(&self) -> Vec<ConnectionId> {
602 self.connection_ids.iter().copied().collect()
603 }
604}