1use crate::db::{self, ChannelId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId, Receipt};
5use std::{collections::hash_map, path::PathBuf};
6use tracing::instrument;
7
8#[derive(Default)]
9pub struct Store {
10 connections: HashMap<ConnectionId, ConnectionState>,
11 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
12 projects: HashMap<u64, Project>,
13 channels: HashMap<ChannelId, Channel>,
14 next_project_id: u64,
15}
16
17struct ConnectionState {
18 user_id: UserId,
19 projects: HashSet<u64>,
20 requested_projects: HashSet<u64>,
21 channels: HashSet<ChannelId>,
22}
23
24pub struct Project {
25 pub host_connection_id: ConnectionId,
26 pub host_user_id: UserId,
27 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
28 pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
29 pub active_replica_ids: HashSet<ReplicaId>,
30 pub worktrees: HashMap<u64, Worktree>,
31 pub language_servers: Vec<proto::LanguageServer>,
32}
33
34#[derive(Default)]
35pub struct Worktree {
36 pub root_name: String,
37 pub visible: bool,
38 pub entries: HashMap<u64, proto::Entry>,
39 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
40 pub scan_id: u64,
41}
42
43#[derive(Default)]
44pub struct Channel {
45 pub connection_ids: HashSet<ConnectionId>,
46}
47
/// Identifier assigned to each participant replica of a shared project.
pub type ReplicaId = u16;
49
50#[derive(Default)]
51pub struct RemovedConnectionState {
52 pub user_id: UserId,
53 pub hosted_projects: HashMap<u64, Project>,
54 pub guest_project_ids: HashSet<u64>,
55 pub contact_ids: HashSet<UserId>,
56}
57
58pub struct LeftProject {
59 pub connection_ids: Vec<ConnectionId>,
60 pub host_user_id: UserId,
61 pub host_connection_id: ConnectionId,
62}
63
/// Point-in-time counters describing the store's size.
#[derive(Clone, Copy)]
pub struct Metrics {
    /// Number of live connections.
    pub connections: usize,
    /// Number of registered projects.
    pub registered_projects: usize,
    /// Number of registered projects with at least one guest.
    pub shared_projects: usize,
}
70
71impl Store {
72 pub fn metrics(&self) -> Metrics {
73 let connections = self.connections.len();
74 let mut registered_projects = 0;
75 let mut shared_projects = 0;
76 for project in self.projects.values() {
77 registered_projects += 1;
78 if !project.guests.is_empty() {
79 shared_projects += 1;
80 }
81 }
82
83 Metrics {
84 connections,
85 registered_projects,
86 shared_projects,
87 }
88 }
89
90 #[instrument(skip(self))]
91 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
92 self.connections.insert(
93 connection_id,
94 ConnectionState {
95 user_id,
96 projects: Default::default(),
97 requested_projects: Default::default(),
98 channels: Default::default(),
99 },
100 );
101 self.connections_by_user_id
102 .entry(user_id)
103 .or_default()
104 .insert(connection_id);
105 }
106
107 #[instrument(skip(self))]
108 pub fn remove_connection(
109 &mut self,
110 connection_id: ConnectionId,
111 ) -> Result<RemovedConnectionState> {
112 let connection = if let Some(connection) = self.connections.remove(&connection_id) {
113 connection
114 } else {
115 return Err(anyhow!("no such connection"))?;
116 };
117
118 for channel_id in &connection.channels {
119 if let Some(channel) = self.channels.get_mut(&channel_id) {
120 channel.connection_ids.remove(&connection_id);
121 }
122 }
123
124 let user_connections = self
125 .connections_by_user_id
126 .get_mut(&connection.user_id)
127 .unwrap();
128 user_connections.remove(&connection_id);
129 if user_connections.is_empty() {
130 self.connections_by_user_id.remove(&connection.user_id);
131 }
132
133 let mut result = RemovedConnectionState::default();
134 result.user_id = connection.user_id;
135 for project_id in connection.projects.clone() {
136 if let Ok(project) = self.unregister_project(project_id, connection_id) {
137 result.hosted_projects.insert(project_id, project);
138 } else if self.leave_project(connection_id, project_id).is_ok() {
139 result.guest_project_ids.insert(project_id);
140 }
141 }
142
143 Ok(result)
144 }
145
146 #[cfg(test)]
147 pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
148 self.channels.get(&id)
149 }
150
151 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
152 if let Some(connection) = self.connections.get_mut(&connection_id) {
153 connection.channels.insert(channel_id);
154 self.channels
155 .entry(channel_id)
156 .or_default()
157 .connection_ids
158 .insert(connection_id);
159 }
160 }
161
162 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
163 if let Some(connection) = self.connections.get_mut(&connection_id) {
164 connection.channels.remove(&channel_id);
165 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
166 entry.get_mut().connection_ids.remove(&connection_id);
167 if entry.get_mut().connection_ids.is_empty() {
168 entry.remove();
169 }
170 }
171 }
172 }
173
174 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
175 Ok(self
176 .connections
177 .get(&connection_id)
178 .ok_or_else(|| anyhow!("unknown connection"))?
179 .user_id)
180 }
181
182 pub fn connection_ids_for_user<'a>(
183 &'a self,
184 user_id: UserId,
185 ) -> impl 'a + Iterator<Item = ConnectionId> {
186 self.connections_by_user_id
187 .get(&user_id)
188 .into_iter()
189 .flatten()
190 .copied()
191 }
192
193 pub fn is_user_online(&self, user_id: UserId) -> bool {
194 !self
195 .connections_by_user_id
196 .get(&user_id)
197 .unwrap_or(&Default::default())
198 .is_empty()
199 }
200
201 pub fn build_initial_contacts_update(
202 &self,
203 contacts: Vec<db::Contact>,
204 ) -> proto::UpdateContacts {
205 let mut update = proto::UpdateContacts::default();
206
207 for contact in contacts {
208 match contact {
209 db::Contact::Accepted {
210 user_id,
211 should_notify,
212 } => {
213 update
214 .contacts
215 .push(self.contact_for_user(user_id, should_notify));
216 }
217 db::Contact::Outgoing { user_id } => {
218 update.outgoing_requests.push(user_id.to_proto())
219 }
220 db::Contact::Incoming {
221 user_id,
222 should_notify,
223 } => update
224 .incoming_requests
225 .push(proto::IncomingContactRequest {
226 requester_id: user_id.to_proto(),
227 should_notify,
228 }),
229 }
230 }
231
232 update
233 }
234
235 pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
236 proto::Contact {
237 user_id: user_id.to_proto(),
238 projects: self.project_metadata_for_user(user_id),
239 online: self.is_user_online(user_id),
240 should_notify,
241 }
242 }
243
244 pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
245 let connection_ids = self.connections_by_user_id.get(&user_id);
246 let project_ids = connection_ids.iter().flat_map(|connection_ids| {
247 connection_ids
248 .iter()
249 .filter_map(|connection_id| self.connections.get(connection_id))
250 .flat_map(|connection| connection.projects.iter().copied())
251 });
252
253 let mut metadata = Vec::new();
254 for project_id in project_ids {
255 if let Some(project) = self.projects.get(&project_id) {
256 if project.host_user_id == user_id {
257 metadata.push(proto::ProjectMetadata {
258 id: project_id,
259 worktree_root_names: project
260 .worktrees
261 .values()
262 .map(|worktree| worktree.root_name.clone())
263 .collect(),
264 guests: project
265 .guests
266 .values()
267 .map(|(_, user_id)| user_id.to_proto())
268 .collect(),
269 });
270 }
271 }
272 }
273
274 metadata
275 }
276
277 pub fn register_project(
278 &mut self,
279 host_connection_id: ConnectionId,
280 host_user_id: UserId,
281 ) -> u64 {
282 let project_id = self.next_project_id;
283 self.projects.insert(
284 project_id,
285 Project {
286 host_connection_id,
287 host_user_id,
288 guests: Default::default(),
289 join_requests: Default::default(),
290 active_replica_ids: Default::default(),
291 worktrees: Default::default(),
292 language_servers: Default::default(),
293 },
294 );
295 if let Some(connection) = self.connections.get_mut(&host_connection_id) {
296 connection.projects.insert(project_id);
297 }
298 self.next_project_id += 1;
299 project_id
300 }
301
302 pub fn register_worktree(
303 &mut self,
304 project_id: u64,
305 worktree_id: u64,
306 connection_id: ConnectionId,
307 worktree: Worktree,
308 ) -> Result<()> {
309 let project = self
310 .projects
311 .get_mut(&project_id)
312 .ok_or_else(|| anyhow!("no such project"))?;
313 if project.host_connection_id == connection_id {
314 project.worktrees.insert(worktree_id, worktree);
315 Ok(())
316 } else {
317 Err(anyhow!("no such project"))?
318 }
319 }
320
321 pub fn unregister_project(
322 &mut self,
323 project_id: u64,
324 connection_id: ConnectionId,
325 ) -> Result<Project> {
326 match self.projects.entry(project_id) {
327 hash_map::Entry::Occupied(e) => {
328 if e.get().host_connection_id == connection_id {
329 let project = e.remove();
330
331 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
332 host_connection.projects.remove(&project_id);
333 }
334
335 for guest_connection in project.guests.keys() {
336 if let Some(connection) = self.connections.get_mut(&guest_connection) {
337 connection.projects.remove(&project_id);
338 }
339 }
340
341 for requester_user_id in project.join_requests.keys() {
342 if let Some(requester_connection_ids) =
343 self.connections_by_user_id.get_mut(&requester_user_id)
344 {
345 for requester_connection_id in requester_connection_ids.iter() {
346 if let Some(requester_connection) =
347 self.connections.get_mut(requester_connection_id)
348 {
349 requester_connection.requested_projects.remove(&project_id);
350 }
351 }
352 }
353 }
354
355 Ok(project)
356 } else {
357 Err(anyhow!("no such project"))?
358 }
359 }
360 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
361 }
362 }
363
364 pub fn unregister_worktree(
365 &mut self,
366 project_id: u64,
367 worktree_id: u64,
368 acting_connection_id: ConnectionId,
369 ) -> Result<(Worktree, Vec<ConnectionId>)> {
370 let project = self
371 .projects
372 .get_mut(&project_id)
373 .ok_or_else(|| anyhow!("no such project"))?;
374 if project.host_connection_id != acting_connection_id {
375 Err(anyhow!("not your worktree"))?;
376 }
377
378 let worktree = project
379 .worktrees
380 .remove(&worktree_id)
381 .ok_or_else(|| anyhow!("no such worktree"))?;
382 Ok((worktree, project.guest_connection_ids()))
383 }
384
385 pub fn update_diagnostic_summary(
386 &mut self,
387 project_id: u64,
388 worktree_id: u64,
389 connection_id: ConnectionId,
390 summary: proto::DiagnosticSummary,
391 ) -> Result<Vec<ConnectionId>> {
392 let project = self
393 .projects
394 .get_mut(&project_id)
395 .ok_or_else(|| anyhow!("no such project"))?;
396 if project.host_connection_id == connection_id {
397 let worktree = project
398 .worktrees
399 .get_mut(&worktree_id)
400 .ok_or_else(|| anyhow!("no such worktree"))?;
401 worktree
402 .diagnostic_summaries
403 .insert(summary.path.clone().into(), summary);
404 return Ok(project.connection_ids());
405 }
406
407 Err(anyhow!("no such worktree"))?
408 }
409
410 pub fn start_language_server(
411 &mut self,
412 project_id: u64,
413 connection_id: ConnectionId,
414 language_server: proto::LanguageServer,
415 ) -> Result<Vec<ConnectionId>> {
416 let project = self
417 .projects
418 .get_mut(&project_id)
419 .ok_or_else(|| anyhow!("no such project"))?;
420 if project.host_connection_id == connection_id {
421 project.language_servers.push(language_server);
422 return Ok(project.connection_ids());
423 }
424
425 Err(anyhow!("no such project"))?
426 }
427
428 pub fn request_join_project(
429 &mut self,
430 requester_id: UserId,
431 project_id: u64,
432 receipt: Receipt<proto::JoinProject>,
433 ) -> Result<()> {
434 let connection = self
435 .connections
436 .get_mut(&receipt.sender_id)
437 .ok_or_else(|| anyhow!("no such connection"))?;
438 let project = self
439 .projects
440 .get_mut(&project_id)
441 .ok_or_else(|| anyhow!("no such project"))?;
442 connection.requested_projects.insert(project_id);
443 project
444 .join_requests
445 .entry(requester_id)
446 .or_default()
447 .push(receipt);
448 Ok(())
449 }
450
451 pub fn deny_join_project_request(
452 &mut self,
453 responder_connection_id: ConnectionId,
454 requester_id: UserId,
455 project_id: u64,
456 ) -> Option<Vec<Receipt<proto::JoinProject>>> {
457 let project = self.projects.get_mut(&project_id)?;
458 if responder_connection_id != project.host_connection_id {
459 return None;
460 }
461
462 let receipts = project.join_requests.remove(&requester_id)?;
463 for receipt in &receipts {
464 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
465 requester_connection.requested_projects.remove(&project_id);
466 }
467 Some(receipts)
468 }
469
470 pub fn accept_join_project_request(
471 &mut self,
472 responder_connection_id: ConnectionId,
473 requester_id: UserId,
474 project_id: u64,
475 ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
476 let project = self.projects.get_mut(&project_id)?;
477 if responder_connection_id != project.host_connection_id {
478 return None;
479 }
480
481 let receipts = project.join_requests.remove(&requester_id)?;
482 let mut receipts_with_replica_ids = Vec::new();
483 for receipt in receipts {
484 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
485 requester_connection.requested_projects.remove(&project_id);
486 requester_connection.projects.insert(project_id);
487 let mut replica_id = 1;
488 while project.active_replica_ids.contains(&replica_id) {
489 replica_id += 1;
490 }
491 project.active_replica_ids.insert(replica_id);
492 project
493 .guests
494 .insert(receipt.sender_id, (replica_id, requester_id));
495 receipts_with_replica_ids.push((receipt, replica_id));
496 }
497
498 Some((receipts_with_replica_ids, project))
499 }
500
501 pub fn leave_project(
502 &mut self,
503 connection_id: ConnectionId,
504 project_id: u64,
505 ) -> Result<LeftProject> {
506 let project = self
507 .projects
508 .get_mut(&project_id)
509 .ok_or_else(|| anyhow!("no such project"))?;
510 let (replica_id, _) = project
511 .guests
512 .remove(&connection_id)
513 .ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
514 project.active_replica_ids.remove(&replica_id);
515
516 if let Some(connection) = self.connections.get_mut(&connection_id) {
517 connection.projects.remove(&project_id);
518 }
519
520 Ok(LeftProject {
521 connection_ids: project.connection_ids(),
522 host_connection_id: project.host_connection_id,
523 host_user_id: project.host_user_id,
524 })
525 }
526
527 pub fn update_worktree(
528 &mut self,
529 connection_id: ConnectionId,
530 project_id: u64,
531 worktree_id: u64,
532 removed_entries: &[u64],
533 updated_entries: &[proto::Entry],
534 scan_id: u64,
535 ) -> Result<Vec<ConnectionId>> {
536 let project = self.write_project(project_id, connection_id)?;
537 let worktree = project
538 .worktrees
539 .get_mut(&worktree_id)
540 .ok_or_else(|| anyhow!("no such worktree"))?;
541 for entry_id in removed_entries {
542 worktree.entries.remove(&entry_id);
543 }
544 for entry in updated_entries {
545 worktree.entries.insert(entry.id, entry.clone());
546 }
547 worktree.scan_id = scan_id;
548 let connection_ids = project.connection_ids();
549 Ok(connection_ids)
550 }
551
552 pub fn project_connection_ids(
553 &self,
554 project_id: u64,
555 acting_connection_id: ConnectionId,
556 ) -> Result<Vec<ConnectionId>> {
557 Ok(self
558 .read_project(project_id, acting_connection_id)?
559 .connection_ids())
560 }
561
562 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
563 Ok(self
564 .channels
565 .get(&channel_id)
566 .ok_or_else(|| anyhow!("no such channel"))?
567 .connection_ids())
568 }
569
570 pub fn project(&self, project_id: u64) -> Result<&Project> {
571 self.projects
572 .get(&project_id)
573 .ok_or_else(|| anyhow!("no such project"))
574 }
575
576 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
577 let project = self
578 .projects
579 .get(&project_id)
580 .ok_or_else(|| anyhow!("no such project"))?;
581 if project.host_connection_id == connection_id
582 || project.guests.contains_key(&connection_id)
583 {
584 Ok(project)
585 } else {
586 Err(anyhow!("no such project"))?
587 }
588 }
589
590 fn write_project(
591 &mut self,
592 project_id: u64,
593 connection_id: ConnectionId,
594 ) -> Result<&mut Project> {
595 let project = self
596 .projects
597 .get_mut(&project_id)
598 .ok_or_else(|| anyhow!("no such project"))?;
599 if project.host_connection_id == connection_id
600 || project.guests.contains_key(&connection_id)
601 {
602 Ok(project)
603 } else {
604 Err(anyhow!("no such project"))?
605 }
606 }
607
608 #[cfg(test)]
609 pub fn check_invariants(&self) {
610 for (connection_id, connection) in &self.connections {
611 for project_id in &connection.projects {
612 let project = &self.projects.get(&project_id).unwrap();
613 if project.host_connection_id != *connection_id {
614 assert!(project.guests.contains_key(connection_id));
615 }
616
617 for (worktree_id, worktree) in project.worktrees.iter() {
618 let mut paths = HashMap::default();
619 for entry in worktree.entries.values() {
620 let prev_entry = paths.insert(&entry.path, entry);
621 assert_eq!(
622 prev_entry,
623 None,
624 "worktree {:?}, duplicate path for entries {:?} and {:?}",
625 worktree_id,
626 prev_entry.unwrap(),
627 entry
628 );
629 }
630 }
631 }
632 for channel_id in &connection.channels {
633 let channel = self.channels.get(channel_id).unwrap();
634 assert!(channel.connection_ids.contains(connection_id));
635 }
636 assert!(self
637 .connections_by_user_id
638 .get(&connection.user_id)
639 .unwrap()
640 .contains(connection_id));
641 }
642
643 for (user_id, connection_ids) in &self.connections_by_user_id {
644 for connection_id in connection_ids {
645 assert_eq!(
646 self.connections.get(connection_id).unwrap().user_id,
647 *user_id
648 );
649 }
650 }
651
652 for (project_id, project) in &self.projects {
653 let host_connection = self.connections.get(&project.host_connection_id).unwrap();
654 assert!(host_connection.projects.contains(project_id));
655
656 for guest_connection_id in project.guests.keys() {
657 let guest_connection = self.connections.get(guest_connection_id).unwrap();
658 assert!(guest_connection.projects.contains(project_id));
659 }
660 assert_eq!(project.active_replica_ids.len(), project.guests.len(),);
661 assert_eq!(
662 project.active_replica_ids,
663 project
664 .guests
665 .values()
666 .map(|(replica_id, _)| *replica_id)
667 .collect::<HashSet<_>>(),
668 );
669 }
670
671 for (channel_id, channel) in &self.channels {
672 for connection_id in &channel.connection_ids {
673 let connection = self.connections.get(connection_id).unwrap();
674 assert!(connection.channels.contains(channel_id));
675 }
676 }
677 }
678}
679
680impl Project {
681 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
682 self.guests.keys().copied().collect()
683 }
684
685 pub fn connection_ids(&self) -> Vec<ConnectionId> {
686 self.guests
687 .keys()
688 .copied()
689 .chain(Some(self.host_connection_id))
690 .collect()
691 }
692}
693
694impl Channel {
695 fn connection_ids(&self) -> Vec<ConnectionId> {
696 self.connection_ids.iter().copied().collect()
697 }
698}