1use crate::db::{ChannelId, MessageId, UserId};
2use crate::errors::TideResultExt;
3use anyhow::anyhow;
4use std::collections::{hash_map, HashMap, HashSet};
5use zrpc::{proto, ConnectionId};
6
7#[derive(Default)]
8pub struct Store {
9 connections: HashMap<ConnectionId, ConnectionState>,
10 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
11 worktrees: HashMap<u64, Worktree>,
12 visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
13 channels: HashMap<ChannelId, Channel>,
14 next_worktree_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 worktrees: HashSet<u64>,
20 channels: HashSet<ChannelId>,
21}
22
23pub struct Worktree {
24 pub host_connection_id: ConnectionId,
25 pub collaborator_user_ids: Vec<UserId>,
26 pub root_name: String,
27 pub share: Option<WorktreeShare>,
28}
29
30struct WorktreeShare {
31 pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
32 pub active_replica_ids: HashSet<ReplicaId>,
33 pub entries: HashMap<u64, proto::Entry>,
34}
35
/// Membership record for a single channel.
#[derive(Default)]
struct Channel {
    /// Connections that have joined this channel.
    connection_ids: HashSet<ConnectionId>,
}
40
/// Identifier assigned to a guest when it joins a shared worktree.
pub type ReplicaId = u16;
42
43#[derive(Default)]
44pub struct RemovedConnectionState {
45 pub hosted_worktrees: HashMap<u64, Worktree>,
46 pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
47 pub collaborator_ids: HashSet<UserId>,
48}
49
50impl Store {
51 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
52 self.connections.insert(
53 connection_id,
54 ConnectionState {
55 user_id,
56 worktrees: Default::default(),
57 channels: Default::default(),
58 },
59 );
60 self.connections_by_user_id
61 .entry(user_id)
62 .or_default()
63 .insert(connection_id);
64 }
65
66 pub fn remove_connection(
67 &mut self,
68 connection_id: ConnectionId,
69 ) -> tide::Result<RemovedConnectionState> {
70 let connection = if let Some(connection) = self.connections.get(&connection_id) {
71 connection
72 } else {
73 return Err(anyhow!("no such connection"))?;
74 };
75
76 for channel_id in connection.channels {
77 if let Some(channel) = self.channels.get_mut(&channel_id) {
78 channel.connection_ids.remove(&connection_id);
79 }
80 }
81
82 let user_connections = self
83 .connections_by_user_id
84 .get_mut(&connection.user_id)
85 .unwrap();
86 user_connections.remove(&connection_id);
87 if user_connections.is_empty() {
88 self.connections_by_user_id.remove(&connection.user_id);
89 }
90
91 let mut result = RemovedConnectionState::default();
92 for worktree_id in connection.worktrees {
93 if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
94 result.hosted_worktrees.insert(worktree_id, worktree);
95 result
96 .collaborator_ids
97 .extend(worktree.collaborator_user_ids.iter().copied());
98 } else {
99 if let Some(worktree) = self.worktrees.get(&worktree_id) {
100 result
101 .guest_worktree_ids
102 .insert(worktree_id, worktree.connection_ids());
103 result
104 .collaborator_ids
105 .extend(worktree.collaborator_user_ids.iter().copied());
106 }
107 }
108 }
109
110 Ok(result)
111 }
112
113 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
114 if let Some(connection) = self.connections.get_mut(&connection_id) {
115 connection.channels.insert(channel_id);
116 self.channels
117 .entry(channel_id)
118 .or_default()
119 .connection_ids
120 .insert(connection_id);
121 }
122 }
123
124 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
125 if let Some(connection) = self.connections.get_mut(&connection_id) {
126 connection.channels.remove(&channel_id);
127 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
128 entry.get_mut().connection_ids.remove(&connection_id);
129 if entry.get_mut().connection_ids.is_empty() {
130 entry.remove();
131 }
132 }
133 }
134 }
135
136 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
137 Ok(self
138 .connections
139 .get(&connection_id)
140 .ok_or_else(|| anyhow!("unknown connection"))?
141 .user_id)
142 }
143
144 pub fn connection_ids_for_user<'a>(
145 &'a self,
146 user_id: UserId,
147 ) -> impl 'a + Iterator<Item = ConnectionId> {
148 self.connections_by_user_id
149 .get(&user_id)
150 .into_iter()
151 .flatten()
152 .copied()
153 }
154
155 pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
156 let mut collaborators = HashMap::new();
157 for worktree_id in self
158 .visible_worktrees_by_user_id
159 .get(&user_id)
160 .unwrap_or(&HashSet::new())
161 {
162 let worktree = &self.worktrees[worktree_id];
163
164 let mut guests = HashSet::new();
165 if let Ok(share) = worktree.share() {
166 for guest_connection_id in share.guest_connection_ids.keys() {
167 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
168 guests.insert(user_id.to_proto());
169 }
170 }
171 }
172
173 if let Ok(host_user_id) = self
174 .user_id_for_connection(worktree.host_connection_id)
175 .context("stale worktree host connection")
176 {
177 let host =
178 collaborators
179 .entry(host_user_id)
180 .or_insert_with(|| proto::Collaborator {
181 user_id: host_user_id.to_proto(),
182 worktrees: Vec::new(),
183 });
184 host.worktrees.push(proto::WorktreeMetadata {
185 root_name: worktree.root_name.clone(),
186 is_shared: worktree.share().is_ok(),
187 participants: guests.into_iter().collect(),
188 });
189 }
190 }
191
192 collaborators.into_values().collect()
193 }
194
195 pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
196 let worktree_id = self.next_worktree_id;
197 for collaborator_user_id in &worktree.collaborator_user_ids {
198 self.visible_worktrees_by_user_id
199 .entry(*collaborator_user_id)
200 .or_default()
201 .insert(worktree_id);
202 }
203 self.next_worktree_id += 1;
204 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
205 connection.worktrees.insert(worktree_id);
206 }
207 self.worktrees.insert(worktree_id, worktree);
208
209 #[cfg(test)]
210 self.check_invariants();
211
212 worktree_id
213 }
214
215 pub fn remove_worktree(
216 &mut self,
217 worktree_id: u64,
218 acting_connection_id: ConnectionId,
219 ) -> tide::Result<Worktree> {
220 let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
221 if e.get().host_connection_id != acting_connection_id {
222 Err(anyhow!("not your worktree"))?;
223 }
224 e.remove()
225 } else {
226 return Err(anyhow!("no such worktree"))?;
227 };
228
229 if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
230 connection.worktrees.remove(&worktree_id);
231 }
232
233 if let Some(share) = worktree.share {
234 for connection_id in share.guest_connection_ids.keys() {
235 if let Some(connection) = self.connections.get_mut(connection_id) {
236 connection.worktrees.remove(&worktree_id);
237 }
238 }
239 }
240
241 for collaborator_user_id in worktree.collaborator_user_ids {
242 if let Some(visible_worktrees) = self
243 .visible_worktrees_by_user_id
244 .get_mut(&collaborator_user_id)
245 {
246 visible_worktrees.remove(&worktree_id);
247 }
248 }
249
250 #[cfg(test)]
251 self.check_invariants();
252
253 Ok(worktree)
254 }
255
256 pub fn share_worktree(
257 &mut self,
258 worktree_id: u64,
259 connection_id: ConnectionId,
260 entries: HashMap<u64, proto::Entry>,
261 ) -> Option<Vec<UserId>> {
262 if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
263 if worktree.host_connection_id == connection_id {
264 worktree.share = Some(WorktreeShare {
265 guest_connection_ids: Default::default(),
266 active_replica_ids: Default::default(),
267 entries,
268 });
269 return Some(worktree.collaborator_user_ids.clone());
270 }
271 }
272 None
273 }
274
275 pub fn unshare_worktree(
276 &mut self,
277 worktree_id: u64,
278 acting_connection_id: ConnectionId,
279 ) -> tide::Result<(Vec<ConnectionId>, Vec<UserId>)> {
280 let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
281 worktree
282 } else {
283 return Err(anyhow!("no such worktree"))?;
284 };
285
286 if worktree.host_connection_id != acting_connection_id {
287 return Err(anyhow!("not your worktree"))?;
288 }
289
290 let connection_ids = worktree.connection_ids();
291
292 if let Some(share) = worktree.share.take() {
293 for connection_id in &connection_ids {
294 if let Some(connection) = self.connections.get_mut(connection_id) {
295 connection.worktrees.remove(&worktree_id);
296 }
297 }
298 Ok((connection_ids, worktree.collaborator_user_ids.clone()))
299 } else {
300 Err(anyhow!("worktree is not shared"))?
301 }
302 }
303
304 pub fn join_worktree(
305 &mut self,
306 connection_id: ConnectionId,
307 user_id: UserId,
308 worktree_id: u64,
309 ) -> tide::Result<(ReplicaId, &Worktree)> {
310 let connection = self
311 .connections
312 .get_mut(&connection_id)
313 .ok_or_else(|| anyhow!("no such connection"))?;
314 let worktree = self
315 .worktrees
316 .get_mut(&worktree_id)
317 .and_then(|worktree| {
318 if worktree.collaborator_user_ids.contains(&user_id) {
319 Some(worktree)
320 } else {
321 None
322 }
323 })
324 .ok_or_else(|| anyhow!("no such worktree"))?;
325
326 let share = worktree.share_mut()?;
327 connection.worktrees.insert(worktree_id);
328
329 let mut replica_id = 1;
330 while share.active_replica_ids.contains(&replica_id) {
331 replica_id += 1;
332 }
333 share.active_replica_ids.insert(replica_id);
334 share.guest_connection_ids.insert(connection_id, replica_id);
335 return Ok((replica_id, worktree));
336 }
337
338 pub fn leave_worktree(
339 &mut self,
340 connection_id: ConnectionId,
341 worktree_id: u64,
342 ) -> Option<(Vec<ConnectionId>, Vec<UserId>)> {
343 let worktree = self.worktrees.get_mut(&worktree_id)?;
344 let share = worktree.share.as_mut()?;
345 let replica_id = share.guest_connection_ids.remove(&connection_id)?;
346 share.active_replica_ids.remove(&replica_id);
347 Some((
348 worktree.connection_ids(),
349 worktree.collaborator_user_ids.clone(),
350 ))
351 }
352
353 pub fn update_worktree(
354 &mut self,
355 connection_id: ConnectionId,
356 worktree_id: u64,
357 removed_entries: &[u64],
358 updated_entries: &[proto::Entry],
359 ) -> tide::Result<Vec<ConnectionId>> {
360 let worktree = self.write_worktree(worktree_id, connection_id)?;
361 let share = worktree.share_mut()?;
362 for entry_id in removed_entries {
363 share.entries.remove(&entry_id);
364 }
365 for entry in updated_entries {
366 share.entries.insert(entry.id, entry.clone());
367 }
368 Ok(worktree.connection_ids())
369 }
370
371 pub fn worktree_host_connection_id(
372 &self,
373 connection_id: ConnectionId,
374 worktree_id: u64,
375 ) -> tide::Result<ConnectionId> {
376 Ok(self
377 .read_worktree(worktree_id, connection_id)?
378 .host_connection_id)
379 }
380
381 pub fn worktree_guest_connection_ids(
382 &self,
383 connection_id: ConnectionId,
384 worktree_id: u64,
385 ) -> tide::Result<Vec<ConnectionId>> {
386 Ok(self
387 .read_worktree(worktree_id, connection_id)?
388 .share()?
389 .guest_connection_ids
390 .keys()
391 .copied()
392 .collect())
393 }
394
395 pub fn worktree_connection_ids(
396 &self,
397 connection_id: ConnectionId,
398 worktree_id: u64,
399 ) -> tide::Result<Vec<ConnectionId>> {
400 Ok(self
401 .read_worktree(worktree_id, connection_id)?
402 .connection_ids())
403 }
404
405 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
406 Some(self.channels.get(&channel_id)?.connection_ids())
407 }
408
409 fn read_worktree(
410 &self,
411 worktree_id: u64,
412 connection_id: ConnectionId,
413 ) -> tide::Result<&Worktree> {
414 let worktree = self
415 .worktrees
416 .get(&worktree_id)
417 .ok_or_else(|| anyhow!("worktree not found"))?;
418
419 if worktree.host_connection_id == connection_id
420 || worktree
421 .share()?
422 .guest_connection_ids
423 .contains_key(&connection_id)
424 {
425 Ok(worktree)
426 } else {
427 Err(anyhow!(
428 "{} is not a member of worktree {}",
429 connection_id,
430 worktree_id
431 ))?
432 }
433 }
434
435 fn write_worktree(
436 &mut self,
437 worktree_id: u64,
438 connection_id: ConnectionId,
439 ) -> tide::Result<&mut Worktree> {
440 let worktree = self
441 .worktrees
442 .get_mut(&worktree_id)
443 .ok_or_else(|| anyhow!("worktree not found"))?;
444
445 if worktree.host_connection_id == connection_id
446 || worktree.share.as_ref().map_or(false, |share| {
447 share.guest_connection_ids.contains_key(&connection_id)
448 })
449 {
450 Ok(worktree)
451 } else {
452 Err(anyhow!(
453 "{} is not a member of worktree {}",
454 connection_id,
455 worktree_id
456 ))?
457 }
458 }
459
    /// Test-only consistency check: panics if the store's tables and indexes
    /// disagree with each other.
    #[cfg(test)]
    fn check_invariants(&self) {
        // Everything a connection references must exist and point back at it.
        for (connection_id, connection) in &self.connections {
            for worktree_id in &connection.worktrees {
                let worktree = &self.worktrees.get(&worktree_id).unwrap();
                // A connection tracks a worktree either as its host or as a
                // guest of its share.
                if worktree.host_connection_id != *connection_id {
                    assert!(worktree
                        .share()
                        .unwrap()
                        .guest_connection_ids
                        .contains_key(connection_id));
                }
            }
            for channel_id in &connection.channels {
                let channel = self.channels.get(channel_id).unwrap();
                assert!(channel.connection_ids.contains(connection_id));
            }
            assert!(self
                .connections_by_user_id
                .get(&connection.user_id)
                .unwrap()
                .contains(connection_id));
        }

        // The per-user index only lists registered connections of that user.
        for (user_id, connection_ids) in &self.connections_by_user_id {
            for connection_id in connection_ids {
                assert_eq!(
                    self.connections.get(connection_id).unwrap().user_id,
                    *user_id
                );
            }
        }

        // Every worktree is tracked by its host, visible to its collaborators,
        // and tracked by each guest of its share.
        for (worktree_id, worktree) in &self.worktrees {
            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
            assert!(host_connection.worktrees.contains(worktree_id));

            for collaborator_id in &worktree.collaborator_user_ids {
                let visible_worktree_ids = self
                    .visible_worktrees_by_user_id
                    .get(collaborator_id)
                    .unwrap();
                assert!(visible_worktree_ids.contains(worktree_id));
            }

            if let Some(share) = &worktree.share {
                for guest_connection_id in share.guest_connection_ids.keys() {
                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
                    assert!(guest_connection.worktrees.contains(worktree_id));
                }
                // The active replica ids are exactly the ids held by guests,
                // with no duplicates.
                assert_eq!(
                    share.active_replica_ids.len(),
                    share.guest_connection_ids.len(),
                );
                assert_eq!(
                    share.active_replica_ids,
                    share
                        .guest_connection_ids
                        .values()
                        .copied()
                        .collect::<HashSet<_>>(),
                );
            }
        }

        // The visibility index only lists existing worktrees that name the
        // user as a collaborator.
        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
            for worktree_id in visible_worktree_ids {
                let worktree = self.worktrees.get(worktree_id).unwrap();
                assert!(worktree.collaborator_user_ids.contains(user_id));
            }
        }

        // Channel membership is mirrored on each member connection.
        for (channel_id, channel) in &self.channels {
            for connection_id in &channel.connection_ids {
                let connection = self.connections.get(connection_id).unwrap();
                assert!(connection.channels.contains(channel_id));
            }
        }
    }
539}
540
541impl Worktree {
542 pub fn connection_ids(&self) -> Vec<ConnectionId> {
543 if let Some(share) = &self.share {
544 share
545 .guest_connection_ids
546 .keys()
547 .copied()
548 .chain(Some(self.host_connection_id))
549 .collect()
550 } else {
551 vec![self.host_connection_id]
552 }
553 }
554
555 pub fn share(&self) -> tide::Result<&WorktreeShare> {
556 Ok(self
557 .share
558 .as_ref()
559 .ok_or_else(|| anyhow!("worktree is not shared"))?)
560 }
561
562 fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
563 Ok(self
564 .share
565 .as_mut()
566 .ok_or_else(|| anyhow!("worktree is not shared"))?)
567 }
568}
569
570impl Channel {
571 fn connection_ids(&self) -> Vec<ConnectionId> {
572 self.connection_ids.iter().copied().collect()
573 }
574}