1use crate::db::{self, ChannelId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{hash_map::Entry, BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId, Receipt};
5use serde::Serialize;
6use std::{collections::hash_map, mem, path::PathBuf};
7use tracing::instrument;
8
9#[derive(Default, Serialize)]
10pub struct Store {
11 connections: HashMap<ConnectionId, ConnectionState>,
12 connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
13 projects: HashMap<u64, Project>,
14 #[serde(skip)]
15 channels: HashMap<ChannelId, Channel>,
16 next_project_id: u64,
17}
18
19#[derive(Serialize)]
20struct ConnectionState {
21 user_id: UserId,
22 projects: HashSet<u64>,
23 requested_projects: HashSet<u64>,
24 channels: HashSet<ChannelId>,
25}
26
27#[derive(Serialize)]
28pub struct Project {
29 pub host_connection_id: ConnectionId,
30 pub host_user_id: UserId,
31 pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
32 #[serde(skip)]
33 pub join_requests: HashMap<UserId, Vec<Receipt<proto::JoinProject>>>,
34 pub active_replica_ids: HashSet<ReplicaId>,
35 pub worktrees: HashMap<u64, Worktree>,
36 pub language_servers: Vec<proto::LanguageServer>,
37}
38
39#[derive(Default, Serialize)]
40pub struct Worktree {
41 pub root_name: String,
42 pub visible: bool,
43 #[serde(skip)]
44 pub entries: HashMap<u64, proto::Entry>,
45 #[serde(skip)]
46 pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
47 pub scan_id: u64,
48}
49
50#[derive(Default)]
51pub struct Channel {
52 pub connection_ids: HashSet<ConnectionId>,
53}
54
/// Identifier assigned to each guest of a project; guest ids are allocated starting at 1.
pub type ReplicaId = u16;
56
57#[derive(Default)]
58pub struct RemovedConnectionState {
59 pub user_id: UserId,
60 pub hosted_projects: HashMap<u64, Project>,
61 pub guest_project_ids: HashSet<u64>,
62 pub contact_ids: HashSet<UserId>,
63}
64
65pub struct LeftProject {
66 pub host_user_id: UserId,
67 pub host_connection_id: ConnectionId,
68 pub connection_ids: Vec<ConnectionId>,
69 pub remove_collaborator: bool,
70 pub cancel_request: Option<UserId>,
71 pub unshare: bool,
72}
73
/// Point-in-time counters describing the store.
#[derive(Copy, Clone)]
pub struct Metrics {
    /// Number of registered connections.
    pub connections: usize,
    /// Number of registered projects.
    pub registered_projects: usize,
    /// Number of registered projects that have at least one guest.
    pub shared_projects: usize,
}
80
81impl Store {
82 pub fn metrics(&self) -> Metrics {
83 let connections = self.connections.len();
84 let mut registered_projects = 0;
85 let mut shared_projects = 0;
86 for project in self.projects.values() {
87 registered_projects += 1;
88 if !project.guests.is_empty() {
89 shared_projects += 1;
90 }
91 }
92
93 Metrics {
94 connections,
95 registered_projects,
96 shared_projects,
97 }
98 }
99
100 #[instrument(skip(self))]
101 pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
102 self.connections.insert(
103 connection_id,
104 ConnectionState {
105 user_id,
106 projects: Default::default(),
107 requested_projects: Default::default(),
108 channels: Default::default(),
109 },
110 );
111 self.connections_by_user_id
112 .entry(user_id)
113 .or_default()
114 .insert(connection_id);
115 }
116
117 #[instrument(skip(self))]
118 pub fn remove_connection(
119 &mut self,
120 connection_id: ConnectionId,
121 ) -> Result<RemovedConnectionState> {
122 let connection = self
123 .connections
124 .get_mut(&connection_id)
125 .ok_or_else(|| anyhow!("no such connection"))?;
126
127 let user_id = connection.user_id;
128 let connection_projects = mem::take(&mut connection.projects);
129 let connection_channels = mem::take(&mut connection.channels);
130
131 let mut result = RemovedConnectionState::default();
132 result.user_id = user_id;
133
134 // Leave all channels.
135 for channel_id in connection_channels {
136 self.leave_channel(connection_id, channel_id);
137 }
138
139 // Unregister and leave all projects.
140 for project_id in connection_projects {
141 if let Ok(project) = self.unregister_project(project_id, connection_id) {
142 result.hosted_projects.insert(project_id, project);
143 } else if self.leave_project(connection_id, project_id).is_ok() {
144 result.guest_project_ids.insert(project_id);
145 }
146 }
147
148 let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap();
149 user_connections.remove(&connection_id);
150 if user_connections.is_empty() {
151 self.connections_by_user_id.remove(&user_id);
152 }
153
154 self.connections.remove(&connection_id).unwrap();
155
156 Ok(result)
157 }
158
/// Test-only lookup of a channel's membership state; `None` when no connection has
/// joined the channel.
#[cfg(test)]
pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
    let channels = &self.channels;
    channels.get(&id)
}
163
164 pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
165 if let Some(connection) = self.connections.get_mut(&connection_id) {
166 connection.channels.insert(channel_id);
167 self.channels
168 .entry(channel_id)
169 .or_default()
170 .connection_ids
171 .insert(connection_id);
172 }
173 }
174
175 pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
176 if let Some(connection) = self.connections.get_mut(&connection_id) {
177 connection.channels.remove(&channel_id);
178 if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
179 entry.get_mut().connection_ids.remove(&connection_id);
180 if entry.get_mut().connection_ids.is_empty() {
181 entry.remove();
182 }
183 }
184 }
185 }
186
187 pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
188 Ok(self
189 .connections
190 .get(&connection_id)
191 .ok_or_else(|| anyhow!("unknown connection"))?
192 .user_id)
193 }
194
195 pub fn connection_ids_for_user<'a>(
196 &'a self,
197 user_id: UserId,
198 ) -> impl 'a + Iterator<Item = ConnectionId> {
199 self.connections_by_user_id
200 .get(&user_id)
201 .into_iter()
202 .flatten()
203 .copied()
204 }
205
206 pub fn is_user_online(&self, user_id: UserId) -> bool {
207 !self
208 .connections_by_user_id
209 .get(&user_id)
210 .unwrap_or(&Default::default())
211 .is_empty()
212 }
213
214 pub fn build_initial_contacts_update(
215 &self,
216 contacts: Vec<db::Contact>,
217 ) -> proto::UpdateContacts {
218 let mut update = proto::UpdateContacts::default();
219
220 for contact in contacts {
221 match contact {
222 db::Contact::Accepted {
223 user_id,
224 should_notify,
225 } => {
226 update
227 .contacts
228 .push(self.contact_for_user(user_id, should_notify));
229 }
230 db::Contact::Outgoing { user_id } => {
231 update.outgoing_requests.push(user_id.to_proto())
232 }
233 db::Contact::Incoming {
234 user_id,
235 should_notify,
236 } => update
237 .incoming_requests
238 .push(proto::IncomingContactRequest {
239 requester_id: user_id.to_proto(),
240 should_notify,
241 }),
242 }
243 }
244
245 update
246 }
247
248 pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact {
249 proto::Contact {
250 user_id: user_id.to_proto(),
251 projects: self.project_metadata_for_user(user_id),
252 online: self.is_user_online(user_id),
253 should_notify,
254 }
255 }
256
257 pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
258 let connection_ids = self.connections_by_user_id.get(&user_id);
259 let project_ids = connection_ids.iter().flat_map(|connection_ids| {
260 connection_ids
261 .iter()
262 .filter_map(|connection_id| self.connections.get(connection_id))
263 .flat_map(|connection| connection.projects.iter().copied())
264 });
265
266 let mut metadata = Vec::new();
267 for project_id in project_ids {
268 if let Some(project) = self.projects.get(&project_id) {
269 if project.host_user_id == user_id {
270 metadata.push(proto::ProjectMetadata {
271 id: project_id,
272 worktree_root_names: project
273 .worktrees
274 .values()
275 .map(|worktree| worktree.root_name.clone())
276 .collect(),
277 guests: project
278 .guests
279 .values()
280 .map(|(_, user_id)| user_id.to_proto())
281 .collect(),
282 });
283 }
284 }
285 }
286
287 metadata
288 }
289
290 pub fn register_project(
291 &mut self,
292 host_connection_id: ConnectionId,
293 host_user_id: UserId,
294 ) -> u64 {
295 let project_id = self.next_project_id;
296 self.projects.insert(
297 project_id,
298 Project {
299 host_connection_id,
300 host_user_id,
301 guests: Default::default(),
302 join_requests: Default::default(),
303 active_replica_ids: Default::default(),
304 worktrees: Default::default(),
305 language_servers: Default::default(),
306 },
307 );
308 if let Some(connection) = self.connections.get_mut(&host_connection_id) {
309 connection.projects.insert(project_id);
310 }
311 self.next_project_id += 1;
312 project_id
313 }
314
315 pub fn register_worktree(
316 &mut self,
317 project_id: u64,
318 worktree_id: u64,
319 connection_id: ConnectionId,
320 worktree: Worktree,
321 ) -> Result<()> {
322 let project = self
323 .projects
324 .get_mut(&project_id)
325 .ok_or_else(|| anyhow!("no such project"))?;
326 if project.host_connection_id == connection_id {
327 project.worktrees.insert(worktree_id, worktree);
328 Ok(())
329 } else {
330 Err(anyhow!("no such project"))?
331 }
332 }
333
334 pub fn unregister_project(
335 &mut self,
336 project_id: u64,
337 connection_id: ConnectionId,
338 ) -> Result<Project> {
339 match self.projects.entry(project_id) {
340 hash_map::Entry::Occupied(e) => {
341 if e.get().host_connection_id == connection_id {
342 let project = e.remove();
343
344 if let Some(host_connection) = self.connections.get_mut(&connection_id) {
345 host_connection.projects.remove(&project_id);
346 }
347
348 for guest_connection in project.guests.keys() {
349 if let Some(connection) = self.connections.get_mut(&guest_connection) {
350 connection.projects.remove(&project_id);
351 }
352 }
353
354 for requester_user_id in project.join_requests.keys() {
355 if let Some(requester_connection_ids) =
356 self.connections_by_user_id.get_mut(&requester_user_id)
357 {
358 for requester_connection_id in requester_connection_ids.iter() {
359 if let Some(requester_connection) =
360 self.connections.get_mut(requester_connection_id)
361 {
362 requester_connection.requested_projects.remove(&project_id);
363 }
364 }
365 }
366 }
367
368 Ok(project)
369 } else {
370 Err(anyhow!("no such project"))?
371 }
372 }
373 hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
374 }
375 }
376
377 pub fn unregister_worktree(
378 &mut self,
379 project_id: u64,
380 worktree_id: u64,
381 acting_connection_id: ConnectionId,
382 ) -> Result<(Worktree, Vec<ConnectionId>)> {
383 let project = self
384 .projects
385 .get_mut(&project_id)
386 .ok_or_else(|| anyhow!("no such project"))?;
387 if project.host_connection_id != acting_connection_id {
388 Err(anyhow!("not your worktree"))?;
389 }
390
391 let worktree = project
392 .worktrees
393 .remove(&worktree_id)
394 .ok_or_else(|| anyhow!("no such worktree"))?;
395 Ok((worktree, project.guest_connection_ids()))
396 }
397
398 pub fn update_diagnostic_summary(
399 &mut self,
400 project_id: u64,
401 worktree_id: u64,
402 connection_id: ConnectionId,
403 summary: proto::DiagnosticSummary,
404 ) -> Result<Vec<ConnectionId>> {
405 let project = self
406 .projects
407 .get_mut(&project_id)
408 .ok_or_else(|| anyhow!("no such project"))?;
409 if project.host_connection_id == connection_id {
410 let worktree = project
411 .worktrees
412 .get_mut(&worktree_id)
413 .ok_or_else(|| anyhow!("no such worktree"))?;
414 worktree
415 .diagnostic_summaries
416 .insert(summary.path.clone().into(), summary);
417 return Ok(project.connection_ids());
418 }
419
420 Err(anyhow!("no such worktree"))?
421 }
422
423 pub fn start_language_server(
424 &mut self,
425 project_id: u64,
426 connection_id: ConnectionId,
427 language_server: proto::LanguageServer,
428 ) -> Result<Vec<ConnectionId>> {
429 let project = self
430 .projects
431 .get_mut(&project_id)
432 .ok_or_else(|| anyhow!("no such project"))?;
433 if project.host_connection_id == connection_id {
434 project.language_servers.push(language_server);
435 return Ok(project.connection_ids());
436 }
437
438 Err(anyhow!("no such project"))?
439 }
440
441 pub fn request_join_project(
442 &mut self,
443 requester_id: UserId,
444 project_id: u64,
445 receipt: Receipt<proto::JoinProject>,
446 ) -> Result<()> {
447 let connection = self
448 .connections
449 .get_mut(&receipt.sender_id)
450 .ok_or_else(|| anyhow!("no such connection"))?;
451 let project = self
452 .projects
453 .get_mut(&project_id)
454 .ok_or_else(|| anyhow!("no such project"))?;
455 connection.requested_projects.insert(project_id);
456 project
457 .join_requests
458 .entry(requester_id)
459 .or_default()
460 .push(receipt);
461 Ok(())
462 }
463
464 pub fn deny_join_project_request(
465 &mut self,
466 responder_connection_id: ConnectionId,
467 requester_id: UserId,
468 project_id: u64,
469 ) -> Option<Vec<Receipt<proto::JoinProject>>> {
470 let project = self.projects.get_mut(&project_id)?;
471 if responder_connection_id != project.host_connection_id {
472 return None;
473 }
474
475 let receipts = project.join_requests.remove(&requester_id)?;
476 for receipt in &receipts {
477 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
478 requester_connection.requested_projects.remove(&project_id);
479 }
480 Some(receipts)
481 }
482
483 pub fn accept_join_project_request(
484 &mut self,
485 responder_connection_id: ConnectionId,
486 requester_id: UserId,
487 project_id: u64,
488 ) -> Option<(Vec<(Receipt<proto::JoinProject>, ReplicaId)>, &Project)> {
489 let project = self.projects.get_mut(&project_id)?;
490 if responder_connection_id != project.host_connection_id {
491 return None;
492 }
493
494 let receipts = project.join_requests.remove(&requester_id)?;
495 let mut receipts_with_replica_ids = Vec::new();
496 for receipt in receipts {
497 let requester_connection = self.connections.get_mut(&receipt.sender_id)?;
498 requester_connection.requested_projects.remove(&project_id);
499 requester_connection.projects.insert(project_id);
500 let mut replica_id = 1;
501 while project.active_replica_ids.contains(&replica_id) {
502 replica_id += 1;
503 }
504 project.active_replica_ids.insert(replica_id);
505 project
506 .guests
507 .insert(receipt.sender_id, (replica_id, requester_id));
508 receipts_with_replica_ids.push((receipt, replica_id));
509 }
510
511 Some((receipts_with_replica_ids, project))
512 }
513
514 pub fn leave_project(
515 &mut self,
516 connection_id: ConnectionId,
517 project_id: u64,
518 ) -> Result<LeftProject> {
519 let user_id = self.user_id_for_connection(connection_id)?;
520 let project = self
521 .projects
522 .get_mut(&project_id)
523 .ok_or_else(|| anyhow!("no such project"))?;
524
525 // If the connection leaving the project is a collaborator, remove it.
526 let remove_collaborator =
527 if let Some((replica_id, _)) = project.guests.remove(&connection_id) {
528 project.active_replica_ids.remove(&replica_id);
529 true
530 } else {
531 false
532 };
533
534 // If the connection leaving the project has a pending request, remove it.
535 // If that user has no other pending requests on other connections, indicate that the request should be cancelled.
536 let mut cancel_request = None;
537 if let Entry::Occupied(mut entry) = project.join_requests.entry(user_id) {
538 entry
539 .get_mut()
540 .retain(|receipt| receipt.sender_id != connection_id);
541 if entry.get().is_empty() {
542 entry.remove();
543 cancel_request = Some(user_id);
544 }
545 }
546
547 if let Some(connection) = self.connections.get_mut(&connection_id) {
548 connection.projects.remove(&project_id);
549 }
550
551 let connection_ids = project.connection_ids();
552 let unshare = connection_ids.len() <= 1 && project.join_requests.is_empty();
553 if unshare {
554 project.language_servers.clear();
555 for worktree in project.worktrees.values_mut() {
556 worktree.diagnostic_summaries.clear();
557 worktree.entries.clear();
558 }
559 }
560
561 Ok(LeftProject {
562 host_connection_id: project.host_connection_id,
563 host_user_id: project.host_user_id,
564 connection_ids,
565 cancel_request,
566 unshare,
567 remove_collaborator,
568 })
569 }
570
571 pub fn update_worktree(
572 &mut self,
573 connection_id: ConnectionId,
574 project_id: u64,
575 worktree_id: u64,
576 removed_entries: &[u64],
577 updated_entries: &[proto::Entry],
578 scan_id: u64,
579 ) -> Result<Vec<ConnectionId>> {
580 let project = self.write_project(project_id, connection_id)?;
581 let worktree = project
582 .worktrees
583 .get_mut(&worktree_id)
584 .ok_or_else(|| anyhow!("no such worktree"))?;
585 for entry_id in removed_entries {
586 worktree.entries.remove(&entry_id);
587 }
588 for entry in updated_entries {
589 worktree.entries.insert(entry.id, entry.clone());
590 }
591 worktree.scan_id = scan_id;
592 let connection_ids = project.connection_ids();
593 Ok(connection_ids)
594 }
595
596 pub fn project_connection_ids(
597 &self,
598 project_id: u64,
599 acting_connection_id: ConnectionId,
600 ) -> Result<Vec<ConnectionId>> {
601 Ok(self
602 .read_project(project_id, acting_connection_id)?
603 .connection_ids())
604 }
605
606 pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result<Vec<ConnectionId>> {
607 Ok(self
608 .channels
609 .get(&channel_id)
610 .ok_or_else(|| anyhow!("no such channel"))?
611 .connection_ids())
612 }
613
614 pub fn project(&self, project_id: u64) -> Result<&Project> {
615 self.projects
616 .get(&project_id)
617 .ok_or_else(|| anyhow!("no such project"))
618 }
619
620 pub fn read_project(&self, project_id: u64, connection_id: ConnectionId) -> Result<&Project> {
621 let project = self
622 .projects
623 .get(&project_id)
624 .ok_or_else(|| anyhow!("no such project"))?;
625 if project.host_connection_id == connection_id
626 || project.guests.contains_key(&connection_id)
627 {
628 Ok(project)
629 } else {
630 Err(anyhow!("no such project"))?
631 }
632 }
633
634 fn write_project(
635 &mut self,
636 project_id: u64,
637 connection_id: ConnectionId,
638 ) -> Result<&mut Project> {
639 let project = self
640 .projects
641 .get_mut(&project_id)
642 .ok_or_else(|| anyhow!("no such project"))?;
643 if project.host_connection_id == connection_id
644 || project.guests.contains_key(&connection_id)
645 {
646 Ok(project)
647 } else {
648 Err(anyhow!("no such project"))?
649 }
650 }
651
/// Test-only consistency check that panics when the store's cross-references disagree.
///
/// Verified, in order:
/// - every project a connection lists exists, and the connection is its host or a guest;
/// - no worktree of such a project has two entries with the same path;
/// - every channel a connection lists exists and contains the connection;
/// - every connection appears in the per-user index, and that index only refers to
///   connections owned by the indexed user;
/// - every project's host and guests are registered and list the project;
/// - a project's active replica ids are exactly the replica ids of its guests;
/// - every channel member is registered and lists the channel.
#[cfg(test)]
pub fn check_invariants(&self) {
    for (connection_id, connection) in self.connections.iter() {
        for project_id in connection.projects.iter() {
            let project = self.projects.get(project_id).unwrap();
            let is_host = project.host_connection_id == *connection_id;
            if !is_host {
                assert!(project.guests.contains_key(connection_id));
            }

            for (worktree_id, worktree) in &project.worktrees {
                let mut entries_by_path = HashMap::default();
                for entry in worktree.entries.values() {
                    let prev_entry = entries_by_path.insert(&entry.path, entry);
                    assert_eq!(
                        prev_entry,
                        None,
                        "worktree {:?}, duplicate path for entries {:?} and {:?}",
                        worktree_id,
                        prev_entry.unwrap(),
                        entry
                    );
                }
            }
        }

        for channel_id in connection.channels.iter() {
            let channel = self.channels.get(channel_id).unwrap();
            assert!(channel.connection_ids.contains(connection_id));
        }

        assert!(self
            .connections_by_user_id
            .get(&connection.user_id)
            .unwrap()
            .contains(connection_id));
    }

    for (user_id, connection_ids) in self.connections_by_user_id.iter() {
        for connection_id in connection_ids.iter() {
            let owner_id = self.connections.get(connection_id).unwrap().user_id;
            assert_eq!(owner_id, *user_id);
        }
    }

    for (project_id, project) in self.projects.iter() {
        let host_connection = self.connections.get(&project.host_connection_id).unwrap();
        assert!(host_connection.projects.contains(project_id));

        for guest_connection_id in project.guests.keys() {
            let guest_connection = self.connections.get(guest_connection_id).unwrap();
            assert!(guest_connection.projects.contains(project_id));
        }

        assert_eq!(project.active_replica_ids.len(), project.guests.len());

        let guest_replica_ids = project
            .guests
            .values()
            .map(|(replica_id, _)| *replica_id)
            .collect::<HashSet<_>>();
        assert_eq!(project.active_replica_ids, guest_replica_ids);
    }

    for (channel_id, channel) in self.channels.iter() {
        for connection_id in channel.connection_ids.iter() {
            let connection = self.connections.get(connection_id).unwrap();
            assert!(connection.channels.contains(channel_id));
        }
    }
}
722}
723
724impl Project {
725 pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
726 self.guests.keys().copied().collect()
727 }
728
729 pub fn connection_ids(&self) -> Vec<ConnectionId> {
730 self.guests
731 .keys()
732 .copied()
733 .chain(Some(self.host_connection_id))
734 .collect()
735 }
736}
737
738impl Channel {
739 fn connection_ids(&self) -> Vec<ConnectionId> {
740 self.connection_ids.iter().copied().collect()
741 }
742}