1mod store;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, MessageId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::HashMap;
26use futures::{
27 channel::mpsc,
28 future::{self, BoxFuture},
29 FutureExt, SinkExt, StreamExt, TryStreamExt,
30};
31use lazy_static::lazy_static;
32use prometheus::{register_int_gauge, IntGauge};
33use rpc::{
34 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
35 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
36};
37use serde::{Serialize, Serializer};
38use std::{
39 any::TypeId,
40 future::Future,
41 marker::PhantomData,
42 net::SocketAddr,
43 ops::{Deref, DerefMut},
44 rc::Rc,
45 sync::{
46 atomic::{AtomicBool, Ordering::SeqCst},
47 Arc,
48 },
49 time::Duration,
50};
51use time::OffsetDateTime;
52use tokio::{
53 sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
54 time::Sleep,
55};
56use tower::ServiceBuilder;
57use tracing::{info_span, instrument, Instrument};
58
59pub use store::{Store, Worktree};
60
// Process-wide Prometheus integer gauges, created lazily on first access through
// `register_int_gauge!`. The `unwrap()` means a registration failure panics at that point.
// NOTE(review): none of these gauges are updated in this part of the file; presumably they are
// set elsewhere — confirm.
lazy_static! {
    // Gauge "connections": number of connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge "projects": number of open projects.
    static ref METRIC_PROJECTS: IntGauge =
        register_int_gauge!("projects", "number of open projects").unwrap();
    // Gauge "shared_projects": number of open projects with one or more guests.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
/// Type-erased handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It receives the server and the boxed incoming envelope, and returns a boxed `'static`
/// future that performs the handling. The closure itself must be `Send + Sync` so the handler
/// table can be shared through `Arc<Server>`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
75
/// Handle a request handler uses to answer a request of type `R`.
///
/// Consuming it (via `send` or `into_receipt`) sets the shared `responded` flag, which
/// `Server::add_request_handler` inspects after the handler's future completes.
struct Response<R> {
    /// Server whose peer is used to send the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Shared with `add_request_handler`; set to `true` once the request has been answered
    /// (or its receipt handed off for a later answer).
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.server.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88
89 fn into_receipt(self) -> Receipt<R> {
90 self.responded.store(true, SeqCst);
91 self.receipt
92 }
93}
94
/// The collaboration RPC server: routes messages from connected clients to typed handlers and
/// keeps shared session state in a `Store`.
pub struct Server {
    /// RPC peer that owns every client connection.
    peer: Arc<Peer>,
    /// In-memory connection/project state, guarded by an async read-write lock.
    pub(crate) store: RwLock<Store>,
    /// Shared application state (database handle, configuration such as the invite link prefix).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// If set, a `()` is sent on this channel after each incoming message has been handled.
    notifications: Option<mpsc::UnboundedSender<()>>,
}
102
/// Abstraction over the async runtime that `Server::handle_connection` runs on.
pub trait Executor: Send + Clone {
    /// Future returned by `sleep`; the peer connection awaits it as a timer.
    type Sleep: Send + Future;
    /// Runs `future` in the background without returning a handle to it. Used for messages
    /// flagged as background so they don't block the connection's message loop.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future for a timer of the given `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
108
/// Production executor.
/// NOTE(review): its `Executor` impl is not in this part of the file; presumably it delegates
/// to the tokio runtime (the file imports `tokio::time::Sleep`) — confirm.
#[derive(Clone)]
pub struct RealExecutor;
111
/// Page size for channel message history.
/// NOTE(review): not used in this part of the file; presumably the number of messages returned
/// per channel-history request — confirm against the channel handlers.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Limit applied to channel message bodies.
/// NOTE(review): not used in this part of the file; presumably the maximum accepted message
/// length (units — bytes vs. chars — unconfirmed).
const MAX_MESSAGE_LEN: usize = 1024;
114
/// Shared (read) access to the server's `Store`, wrapping a tokio read guard.
///
/// The `PhantomData<Rc<()>>` marker makes this type neither `Send` nor `Sync`. Presumably this
/// is so the compiler rejects holding the store lock across an `.await` in a future that must
/// be `Send` — confirm against the accessor methods defined elsewhere in the file.
struct StoreReadGuard<'a> {
    guard: RwLockReadGuard<'a, Store>,
    // Opts out of `Send`/`Sync`; see the type-level comment.
    _not_send: PhantomData<Rc<()>>,
}

/// Exclusive (write) access to the server's `Store`, wrapping a tokio write guard.
///
/// Like `StoreReadGuard`, the `PhantomData<Rc<()>>` marker makes it neither `Send` nor `Sync`.
struct StoreWriteGuard<'a> {
    guard: RwLockWriteGuard<'a, Store>,
    // Opts out of `Send`/`Sync`; see the type-level comment.
    _not_send: PhantomData<Rc<()>>,
}
124
/// Serializable view of the server: the peer plus the store, read through a held read guard.
///
/// The store lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`, so serialize the `Store` it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    store: RwLockReadGuard<'a, Store>,
}
131
132pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
133where
134 S: Serializer,
135 T: Deref<Target = U>,
136 U: Serialize,
137{
138 Serialize::serialize(value.deref(), serializer)
139}
140
141impl Server {
    /// Builds the server and registers a handler for every RPC message it understands.
    ///
    /// `notifications`, when provided, receives a `()` after each incoming message has been
    /// handled. Registering two handlers for the same message type panics (see
    /// `add_message_handler`), so each message type below must appear exactly once.
    pub fn new(
        app_state: Arc<AppState>,
        notifications: Option<mpsc::UnboundedSender<()>>,
    ) -> Arc<Self> {
        let mut server = Self {
            peer: Peer::new(),
            app_state,
            store: Default::default(),
            handlers: Default::default(),
            notifications,
        };

        // Request handlers must answer through their `Response`; message handlers are
        // one-way and produce no reply.
        server
            .add_request_handler(Server::ping)
            // Project lifecycle, membership and metadata.
            .add_request_handler(Server::register_project)
            .add_request_handler(Server::unregister_project)
            .add_request_handler(Server::join_project)
            .add_message_handler(Server::leave_project)
            .add_message_handler(Server::respond_to_join_project_request)
            .add_message_handler(Server::update_project)
            .add_request_handler(Server::update_worktree)
            .add_message_handler(Server::start_language_server)
            .add_message_handler(Server::update_language_server)
            .add_message_handler(Server::update_diagnostic_summary)
            // Requests relayed verbatim to the project's host connection.
            .add_request_handler(Server::forward_project_request::<proto::GetHover>)
            .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
            .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
            .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
            .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
            .add_request_handler(
                Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
            .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
            .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
            // Buffer synchronization.
            .add_request_handler(Server::update_buffer)
            .add_message_handler(Server::update_buffer_file)
            .add_message_handler(Server::buffer_reloaded)
            .add_message_handler(Server::buffer_saved)
            .add_request_handler(Server::save_buffer)
            // Channels, users and contacts.
            .add_request_handler(Server::get_channels)
            .add_request_handler(Server::get_users)
            .add_request_handler(Server::fuzzy_search_users)
            .add_request_handler(Server::request_contact)
            .add_request_handler(Server::remove_contact)
            .add_request_handler(Server::respond_to_contact_request)
            .add_request_handler(Server::join_channel)
            .add_message_handler(Server::leave_channel)
            .add_request_handler(Server::send_channel_message)
            // Following other collaborators.
            .add_request_handler(Server::follow)
            .add_message_handler(Server::unfollow)
            .add_message_handler(Server::update_followers)
            .add_request_handler(Server::get_channel_messages);

        Arc::new(server)
    }
210
211 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
212 where
213 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
214 Fut: 'static + Send + Future<Output = Result<()>>,
215 M: EnvelopedMessage,
216 {
217 let prev_handler = self.handlers.insert(
218 TypeId::of::<M>(),
219 Box::new(move |server, envelope| {
220 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
221 let span = info_span!(
222 "handle message",
223 payload_type = envelope.payload_type_name(),
224 payload = format!("{:?}", envelope.payload).as_str(),
225 );
226 let future = (handler)(server, *envelope);
227 async move {
228 if let Err(error) = future.await {
229 tracing::error!(%error, "error handling message");
230 }
231 }
232 .instrument(span)
233 .boxed()
234 }),
235 );
236 if prev_handler.is_some() {
237 panic!("registered a handler for the same message twice");
238 }
239 self
240 }
241
242 /// Handle a request while holding a lock to the store. This is useful when we're registering
243 /// a connection but we want to respond on the connection before anybody else can send on it.
244 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
245 where
246 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
247 Fut: Send + Future<Output = Result<()>>,
248 M: RequestMessage,
249 {
250 let handler = Arc::new(handler);
251 self.add_message_handler(move |server, envelope| {
252 let receipt = envelope.receipt();
253 let handler = handler.clone();
254 async move {
255 let responded = Arc::new(AtomicBool::default());
256 let response = Response {
257 server: server.clone(),
258 responded: responded.clone(),
259 receipt: envelope.receipt(),
260 };
261 match (handler)(server.clone(), envelope, response).await {
262 Ok(()) => {
263 if responded.load(std::sync::atomic::Ordering::SeqCst) {
264 Ok(())
265 } else {
266 Err(anyhow!("handler did not send a response"))?
267 }
268 }
269 Err(error) => {
270 server.peer.respond_with_error(
271 receipt,
272 proto::Error {
273 message: error.to_string(),
274 },
275 )?;
276 Err(error)
277 }
278 }
279 }
280 })
281 }
282
283 pub fn handle_connection<E: Executor>(
284 self: &Arc<Self>,
285 connection: Connection,
286 address: String,
287 user: User,
288 mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
289 executor: E,
290 ) -> impl Future<Output = Result<()>> {
291 let mut this = self.clone();
292 let user_id = user.id;
293 let login = user.github_login;
294 let span = info_span!("handle connection", %user_id, %login, %address);
295 async move {
296 let (connection_id, handle_io, mut incoming_rx) = this
297 .peer
298 .add_connection(connection, {
299 let executor = executor.clone();
300 move |duration| {
301 let timer = executor.sleep(duration);
302 async move {
303 timer.await;
304 }
305 }
306 })
307 .await;
308
309 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
310
311 if let Some(send_connection_id) = send_connection_id.as_mut() {
312 let _ = send_connection_id.send(connection_id).await;
313 }
314
315 if !user.connected_once {
316 this.peer.send(connection_id, proto::ShowContacts {})?;
317 this.app_state.db.set_user_connected_once(user_id, true).await?;
318 }
319
320 let (contacts, invite_code) = future::try_join(
321 this.app_state.db.get_contacts(user_id),
322 this.app_state.db.get_invite_code_for_user(user_id)
323 ).await?;
324
325 {
326 let mut store = this.store_mut().await;
327 store.add_connection(connection_id, user_id);
328 this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
329
330 if let Some((code, count)) = invite_code {
331 this.peer.send(connection_id, proto::UpdateInviteInfo {
332 url: format!("{}{}", this.app_state.invite_link_prefix, code),
333 count,
334 })?;
335 }
336 }
337 this.update_user_contacts(user_id).await?;
338
339 let handle_io = handle_io.fuse();
340 futures::pin_mut!(handle_io);
341 loop {
342 let next_message = incoming_rx.next().fuse();
343 futures::pin_mut!(next_message);
344 futures::select_biased! {
345 result = handle_io => {
346 if let Err(error) = result {
347 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
348 }
349 break;
350 }
351 message = next_message => {
352 if let Some(message) = message {
353 let type_name = message.payload_type_name();
354 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
355 async {
356 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
357 let notifications = this.notifications.clone();
358 let is_background = message.is_background();
359 let handle_message = (handler)(this.clone(), message);
360 let handle_message = async move {
361 handle_message.await;
362 if let Some(mut notifications) = notifications {
363 let _ = notifications.send(()).await;
364 }
365 };
366 if is_background {
367 executor.spawn_detached(handle_message);
368 } else {
369 handle_message.await;
370 }
371 } else {
372 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
373 }
374 }.instrument(span).await;
375 } else {
376 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
377 break;
378 }
379 }
380 }
381 }
382
383 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
384 if let Err(error) = this.sign_out(connection_id).await {
385 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
386 }
387
388 Ok(())
389 }.instrument(span)
390 }
391
392 #[instrument(skip(self), err)]
393 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
394 self.peer.disconnect(connection_id);
395
396 let removed_user_id = {
397 let mut store = self.store_mut().await;
398 let removed_connection = store.remove_connection(connection_id)?;
399
400 for (project_id, project) in removed_connection.hosted_projects {
401 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
402 self.peer
403 .send(conn_id, proto::UnregisterProject { project_id })
404 });
405
406 for (_, receipts) in project.join_requests {
407 for receipt in receipts {
408 self.peer.respond(
409 receipt,
410 proto::JoinProjectResponse {
411 variant: Some(proto::join_project_response::Variant::Decline(
412 proto::join_project_response::Decline {
413 reason: proto::join_project_response::decline::Reason::WentOffline as i32
414 },
415 )),
416 },
417 )?;
418 }
419 }
420 }
421
422 for project_id in removed_connection.guest_project_ids {
423 if let Some(project) = store.project(project_id).trace_err() {
424 broadcast(connection_id, project.connection_ids(), |conn_id| {
425 self.peer.send(
426 conn_id,
427 proto::RemoveProjectCollaborator {
428 project_id,
429 peer_id: connection_id.0,
430 },
431 )
432 });
433 if project.guests.is_empty() {
434 self.peer
435 .send(
436 project.host_connection_id,
437 proto::ProjectUnshared { project_id },
438 )
439 .trace_err();
440 }
441 }
442 }
443
444 removed_connection.user_id
445 };
446
447 self.update_user_contacts(removed_user_id).await?;
448
449 Ok(())
450 }
451
452 pub async fn invite_code_redeemed(
453 self: &Arc<Self>,
454 code: &str,
455 invitee_id: UserId,
456 ) -> Result<()> {
457 let user = self.app_state.db.get_user_for_invite_code(code).await?;
458 let store = self.store().await;
459 let invitee_contact = store.contact_for_user(invitee_id, true);
460 for connection_id in store.connection_ids_for_user(user.id) {
461 self.peer.send(
462 connection_id,
463 proto::UpdateContacts {
464 contacts: vec![invitee_contact.clone()],
465 ..Default::default()
466 },
467 )?;
468 self.peer.send(
469 connection_id,
470 proto::UpdateInviteInfo {
471 url: format!("{}{}", self.app_state.invite_link_prefix, code),
472 count: user.invite_count as u32,
473 },
474 )?;
475 }
476 Ok(())
477 }
478
479 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
480 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
481 if let Some(invite_code) = &user.invite_code {
482 let store = self.store().await;
483 for connection_id in store.connection_ids_for_user(user_id) {
484 self.peer.send(
485 connection_id,
486 proto::UpdateInviteInfo {
487 url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
488 count: user.invite_count as u32,
489 },
490 )?;
491 }
492 }
493 }
494 Ok(())
495 }
496
497 async fn ping(
498 self: Arc<Server>,
499 _: TypedEnvelope<proto::Ping>,
500 response: Response<proto::Ping>,
501 ) -> Result<()> {
502 response.send(proto::Ack {})?;
503 Ok(())
504 }
505
506 async fn register_project(
507 self: Arc<Server>,
508 request: TypedEnvelope<proto::RegisterProject>,
509 response: Response<proto::RegisterProject>,
510 ) -> Result<()> {
511 let project_id;
512 {
513 let mut state = self.store_mut().await;
514 let user_id = state.user_id_for_connection(request.sender_id)?;
515 project_id = state.register_project(request.sender_id, user_id);
516 };
517
518 response.send(proto::RegisterProjectResponse { project_id })?;
519
520 Ok(())
521 }
522
523 async fn unregister_project(
524 self: Arc<Server>,
525 request: TypedEnvelope<proto::UnregisterProject>,
526 response: Response<proto::UnregisterProject>,
527 ) -> Result<()> {
528 let (user_id, project) = {
529 let mut state = self.store_mut().await;
530 let project =
531 state.unregister_project(request.payload.project_id, request.sender_id)?;
532 (state.user_id_for_connection(request.sender_id)?, project)
533 };
534
535 broadcast(
536 request.sender_id,
537 project.guests.keys().copied(),
538 |conn_id| {
539 self.peer.send(
540 conn_id,
541 proto::UnregisterProject {
542 project_id: request.payload.project_id,
543 },
544 )
545 },
546 );
547 for (_, receipts) in project.join_requests {
548 for receipt in receipts {
549 self.peer.respond(
550 receipt,
551 proto::JoinProjectResponse {
552 variant: Some(proto::join_project_response::Variant::Decline(
553 proto::join_project_response::Decline {
554 reason: proto::join_project_response::decline::Reason::Closed
555 as i32,
556 },
557 )),
558 },
559 )?;
560 }
561 }
562
563 // Send out the `UpdateContacts` message before responding to the unregister
564 // request. This way, when the project's host can keep track of the project's
565 // remote id until after they've received the `UpdateContacts` message for
566 // themself.
567 self.update_user_contacts(user_id).await?;
568 response.send(proto::Ack {})?;
569
570 Ok(())
571 }
572
573 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
574 let contacts = self.app_state.db.get_contacts(user_id).await?;
575 let store = self.store().await;
576 let updated_contact = store.contact_for_user(user_id, false);
577 for contact in contacts {
578 if let db::Contact::Accepted {
579 user_id: contact_user_id,
580 ..
581 } = contact
582 {
583 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
584 self.peer
585 .send(
586 contact_conn_id,
587 proto::UpdateContacts {
588 contacts: vec![updated_contact.clone()],
589 remove_contacts: Default::default(),
590 incoming_requests: Default::default(),
591 remove_incoming_requests: Default::default(),
592 outgoing_requests: Default::default(),
593 remove_outgoing_requests: Default::default(),
594 },
595 )
596 .trace_err();
597 }
598 }
599 }
600 Ok(())
601 }
602
603 async fn join_project(
604 self: Arc<Server>,
605 request: TypedEnvelope<proto::JoinProject>,
606 response: Response<proto::JoinProject>,
607 ) -> Result<()> {
608 let project_id = request.payload.project_id;
609
610 let host_user_id;
611 let guest_user_id;
612 let host_connection_id;
613 {
614 let state = self.store().await;
615 let project = state.project(project_id)?;
616 host_user_id = project.host_user_id;
617 host_connection_id = project.host_connection_id;
618 guest_user_id = state.user_id_for_connection(request.sender_id)?;
619 };
620
621 let has_contact = self
622 .app_state
623 .db
624 .has_contact(guest_user_id, host_user_id)
625 .await?;
626 if !has_contact {
627 return Err(anyhow!("no such project"))?;
628 }
629
630 self.store_mut().await.request_join_project(
631 guest_user_id,
632 project_id,
633 response.into_receipt(),
634 )?;
635 self.peer.send(
636 host_connection_id,
637 proto::RequestJoinProject {
638 project_id,
639 requester_id: guest_user_id.to_proto(),
640 },
641 )?;
642 Ok(())
643 }
644
645 async fn respond_to_join_project_request(
646 self: Arc<Server>,
647 request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
648 ) -> Result<()> {
649 let host_user_id;
650
651 {
652 let mut state = self.store_mut().await;
653 let project_id = request.payload.project_id;
654 let project = state.project(project_id)?;
655 if project.host_connection_id != request.sender_id {
656 Err(anyhow!("no such connection"))?;
657 }
658
659 host_user_id = project.host_user_id;
660 let guest_user_id = UserId::from_proto(request.payload.requester_id);
661
662 if !request.payload.allow {
663 let receipts = state
664 .deny_join_project_request(request.sender_id, guest_user_id, project_id)
665 .ok_or_else(|| anyhow!("no such request"))?;
666 for receipt in receipts {
667 self.peer.respond(
668 receipt,
669 proto::JoinProjectResponse {
670 variant: Some(proto::join_project_response::Variant::Decline(
671 proto::join_project_response::Decline {
672 reason: proto::join_project_response::decline::Reason::Declined
673 as i32,
674 },
675 )),
676 },
677 )?;
678 }
679 return Ok(());
680 }
681
682 let (receipts_with_replica_ids, project) = state
683 .accept_join_project_request(request.sender_id, guest_user_id, project_id)
684 .ok_or_else(|| anyhow!("no such request"))?;
685
686 let peer_count = project.guests.len();
687 let mut collaborators = Vec::with_capacity(peer_count);
688 collaborators.push(proto::Collaborator {
689 peer_id: project.host_connection_id.0,
690 replica_id: 0,
691 user_id: project.host_user_id.to_proto(),
692 });
693 let worktrees = project
694 .worktrees
695 .iter()
696 .filter_map(|(id, shared_worktree)| {
697 let worktree = project.worktrees.get(&id)?;
698 Some(proto::Worktree {
699 id: *id,
700 root_name: worktree.root_name.clone(),
701 entries: shared_worktree.entries.values().cloned().collect(),
702 diagnostic_summaries: shared_worktree
703 .diagnostic_summaries
704 .values()
705 .cloned()
706 .collect(),
707 visible: worktree.visible,
708 scan_id: shared_worktree.scan_id,
709 })
710 })
711 .collect::<Vec<_>>();
712
713 // Add all guests other than the requesting user's own connections as collaborators
714 for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
715 if receipts_with_replica_ids
716 .iter()
717 .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
718 {
719 collaborators.push(proto::Collaborator {
720 peer_id: peer_conn_id.0,
721 replica_id: *peer_replica_id as u32,
722 user_id: peer_user_id.to_proto(),
723 });
724 }
725 }
726
727 for conn_id in project.connection_ids() {
728 for (receipt, replica_id) in &receipts_with_replica_ids {
729 if conn_id != receipt.sender_id {
730 self.peer.send(
731 conn_id,
732 proto::AddProjectCollaborator {
733 project_id,
734 collaborator: Some(proto::Collaborator {
735 peer_id: receipt.sender_id.0,
736 replica_id: *replica_id as u32,
737 user_id: guest_user_id.to_proto(),
738 }),
739 },
740 )?;
741 }
742 }
743 }
744
745 for (receipt, replica_id) in receipts_with_replica_ids {
746 self.peer.respond(
747 receipt,
748 proto::JoinProjectResponse {
749 variant: Some(proto::join_project_response::Variant::Accept(
750 proto::join_project_response::Accept {
751 worktrees: worktrees.clone(),
752 replica_id: replica_id as u32,
753 collaborators: collaborators.clone(),
754 language_servers: project.language_servers.clone(),
755 },
756 )),
757 },
758 )?;
759 }
760 }
761
762 self.update_user_contacts(host_user_id).await?;
763 Ok(())
764 }
765
766 async fn leave_project(
767 self: Arc<Server>,
768 request: TypedEnvelope<proto::LeaveProject>,
769 ) -> Result<()> {
770 let sender_id = request.sender_id;
771 let project_id = request.payload.project_id;
772 let project;
773 {
774 let mut store = self.store_mut().await;
775 project = store.leave_project(sender_id, project_id)?;
776
777 if project.remove_collaborator {
778 broadcast(sender_id, project.connection_ids, |conn_id| {
779 self.peer.send(
780 conn_id,
781 proto::RemoveProjectCollaborator {
782 project_id,
783 peer_id: sender_id.0,
784 },
785 )
786 });
787 }
788
789 if let Some(requester_id) = project.cancel_request {
790 self.peer.send(
791 project.host_connection_id,
792 proto::JoinProjectRequestCancelled {
793 project_id,
794 requester_id: requester_id.to_proto(),
795 },
796 )?;
797 }
798
799 if project.unshare {
800 self.peer.send(
801 project.host_connection_id,
802 proto::ProjectUnshared { project_id },
803 )?;
804 }
805 }
806 self.update_user_contacts(project.host_user_id).await?;
807 Ok(())
808 }
809
810 async fn update_project(
811 self: Arc<Server>,
812 request: TypedEnvelope<proto::UpdateProject>,
813 ) -> Result<()> {
814 let user_id;
815 {
816 let mut state = self.store_mut().await;
817 user_id = state.user_id_for_connection(request.sender_id)?;
818 let guest_connection_ids = state
819 .read_project(request.payload.project_id, request.sender_id)?
820 .guest_connection_ids();
821 state.update_project(
822 request.payload.project_id,
823 &request.payload.worktrees,
824 request.sender_id,
825 )?;
826 broadcast(request.sender_id, guest_connection_ids, |connection_id| {
827 self.peer
828 .forward_send(request.sender_id, connection_id, request.payload.clone())
829 });
830 };
831 self.update_user_contacts(user_id).await?;
832 Ok(())
833 }
834
835 async fn update_worktree(
836 self: Arc<Server>,
837 request: TypedEnvelope<proto::UpdateWorktree>,
838 response: Response<proto::UpdateWorktree>,
839 ) -> Result<()> {
840 let (connection_ids, metadata_changed) = self.store_mut().await.update_worktree(
841 request.sender_id,
842 request.payload.project_id,
843 request.payload.worktree_id,
844 &request.payload.root_name,
845 &request.payload.removed_entries,
846 &request.payload.updated_entries,
847 request.payload.scan_id,
848 )?;
849
850 broadcast(request.sender_id, connection_ids, |connection_id| {
851 self.peer
852 .forward_send(request.sender_id, connection_id, request.payload.clone())
853 });
854 if metadata_changed {
855 let user_id = self
856 .store()
857 .await
858 .user_id_for_connection(request.sender_id)?;
859 self.update_user_contacts(user_id).await?;
860 }
861 response.send(proto::Ack {})?;
862 Ok(())
863 }
864
865 async fn update_diagnostic_summary(
866 self: Arc<Server>,
867 request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
868 ) -> Result<()> {
869 let summary = request
870 .payload
871 .summary
872 .clone()
873 .ok_or_else(|| anyhow!("invalid summary"))?;
874 let receiver_ids = self.store_mut().await.update_diagnostic_summary(
875 request.payload.project_id,
876 request.payload.worktree_id,
877 request.sender_id,
878 summary,
879 )?;
880
881 broadcast(request.sender_id, receiver_ids, |connection_id| {
882 self.peer
883 .forward_send(request.sender_id, connection_id, request.payload.clone())
884 });
885 Ok(())
886 }
887
888 async fn start_language_server(
889 self: Arc<Server>,
890 request: TypedEnvelope<proto::StartLanguageServer>,
891 ) -> Result<()> {
892 let receiver_ids = self.store_mut().await.start_language_server(
893 request.payload.project_id,
894 request.sender_id,
895 request
896 .payload
897 .server
898 .clone()
899 .ok_or_else(|| anyhow!("invalid language server"))?,
900 )?;
901 broadcast(request.sender_id, receiver_ids, |connection_id| {
902 self.peer
903 .forward_send(request.sender_id, connection_id, request.payload.clone())
904 });
905 Ok(())
906 }
907
908 async fn update_language_server(
909 self: Arc<Server>,
910 request: TypedEnvelope<proto::UpdateLanguageServer>,
911 ) -> Result<()> {
912 let receiver_ids = self
913 .store()
914 .await
915 .project_connection_ids(request.payload.project_id, request.sender_id)?;
916 broadcast(request.sender_id, receiver_ids, |connection_id| {
917 self.peer
918 .forward_send(request.sender_id, connection_id, request.payload.clone())
919 });
920 Ok(())
921 }
922
923 async fn forward_project_request<T>(
924 self: Arc<Server>,
925 request: TypedEnvelope<T>,
926 response: Response<T>,
927 ) -> Result<()>
928 where
929 T: EntityMessage + RequestMessage,
930 {
931 let host_connection_id = self
932 .store()
933 .await
934 .read_project(request.payload.remote_entity_id(), request.sender_id)?
935 .host_connection_id;
936
937 response.send(
938 self.peer
939 .forward_request(request.sender_id, host_connection_id, request.payload)
940 .await?,
941 )?;
942 Ok(())
943 }
944
945 async fn save_buffer(
946 self: Arc<Server>,
947 request: TypedEnvelope<proto::SaveBuffer>,
948 response: Response<proto::SaveBuffer>,
949 ) -> Result<()> {
950 let host = self
951 .store()
952 .await
953 .read_project(request.payload.project_id, request.sender_id)?
954 .host_connection_id;
955 let response_payload = self
956 .peer
957 .forward_request(request.sender_id, host, request.payload.clone())
958 .await?;
959
960 let mut guests = self
961 .store()
962 .await
963 .read_project(request.payload.project_id, request.sender_id)?
964 .connection_ids();
965 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
966 broadcast(host, guests, |conn_id| {
967 self.peer
968 .forward_send(host, conn_id, response_payload.clone())
969 });
970 response.send(response_payload)?;
971 Ok(())
972 }
973
974 async fn update_buffer(
975 self: Arc<Server>,
976 request: TypedEnvelope<proto::UpdateBuffer>,
977 response: Response<proto::UpdateBuffer>,
978 ) -> Result<()> {
979 let receiver_ids = self
980 .store()
981 .await
982 .project_connection_ids(request.payload.project_id, request.sender_id)?;
983 broadcast(request.sender_id, receiver_ids, |connection_id| {
984 self.peer
985 .forward_send(request.sender_id, connection_id, request.payload.clone())
986 });
987 response.send(proto::Ack {})?;
988 Ok(())
989 }
990
991 async fn update_buffer_file(
992 self: Arc<Server>,
993 request: TypedEnvelope<proto::UpdateBufferFile>,
994 ) -> Result<()> {
995 let receiver_ids = self
996 .store()
997 .await
998 .project_connection_ids(request.payload.project_id, request.sender_id)?;
999 broadcast(request.sender_id, receiver_ids, |connection_id| {
1000 self.peer
1001 .forward_send(request.sender_id, connection_id, request.payload.clone())
1002 });
1003 Ok(())
1004 }
1005
1006 async fn buffer_reloaded(
1007 self: Arc<Server>,
1008 request: TypedEnvelope<proto::BufferReloaded>,
1009 ) -> Result<()> {
1010 let receiver_ids = self
1011 .store()
1012 .await
1013 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1014 broadcast(request.sender_id, receiver_ids, |connection_id| {
1015 self.peer
1016 .forward_send(request.sender_id, connection_id, request.payload.clone())
1017 });
1018 Ok(())
1019 }
1020
1021 async fn buffer_saved(
1022 self: Arc<Server>,
1023 request: TypedEnvelope<proto::BufferSaved>,
1024 ) -> Result<()> {
1025 let receiver_ids = self
1026 .store()
1027 .await
1028 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1029 broadcast(request.sender_id, receiver_ids, |connection_id| {
1030 self.peer
1031 .forward_send(request.sender_id, connection_id, request.payload.clone())
1032 });
1033 Ok(())
1034 }
1035
1036 async fn follow(
1037 self: Arc<Self>,
1038 request: TypedEnvelope<proto::Follow>,
1039 response: Response<proto::Follow>,
1040 ) -> Result<()> {
1041 let leader_id = ConnectionId(request.payload.leader_id);
1042 let follower_id = request.sender_id;
1043 if !self
1044 .store()
1045 .await
1046 .project_connection_ids(request.payload.project_id, follower_id)?
1047 .contains(&leader_id)
1048 {
1049 Err(anyhow!("no such peer"))?;
1050 }
1051 let mut response_payload = self
1052 .peer
1053 .forward_request(request.sender_id, leader_id, request.payload)
1054 .await?;
1055 response_payload
1056 .views
1057 .retain(|view| view.leader_id != Some(follower_id.0));
1058 response.send(response_payload)?;
1059 Ok(())
1060 }
1061
1062 async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1063 let leader_id = ConnectionId(request.payload.leader_id);
1064 if !self
1065 .store()
1066 .await
1067 .project_connection_ids(request.payload.project_id, request.sender_id)?
1068 .contains(&leader_id)
1069 {
1070 Err(anyhow!("no such peer"))?;
1071 }
1072 self.peer
1073 .forward_send(request.sender_id, leader_id, request.payload)?;
1074 Ok(())
1075 }
1076
1077 async fn update_followers(
1078 self: Arc<Self>,
1079 request: TypedEnvelope<proto::UpdateFollowers>,
1080 ) -> Result<()> {
1081 let connection_ids = self
1082 .store()
1083 .await
1084 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1085 let leader_id = request
1086 .payload
1087 .variant
1088 .as_ref()
1089 .and_then(|variant| match variant {
1090 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1091 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1092 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1093 });
1094 for follower_id in &request.payload.follower_ids {
1095 let follower_id = ConnectionId(*follower_id);
1096 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1097 self.peer
1098 .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1099 }
1100 }
1101 Ok(())
1102 }
1103
1104 async fn get_channels(
1105 self: Arc<Server>,
1106 request: TypedEnvelope<proto::GetChannels>,
1107 response: Response<proto::GetChannels>,
1108 ) -> Result<()> {
1109 let user_id = self
1110 .store()
1111 .await
1112 .user_id_for_connection(request.sender_id)?;
1113 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1114 response.send(proto::GetChannelsResponse {
1115 channels: channels
1116 .into_iter()
1117 .map(|chan| proto::Channel {
1118 id: chan.id.to_proto(),
1119 name: chan.name,
1120 })
1121 .collect(),
1122 })?;
1123 Ok(())
1124 }
1125
1126 async fn get_users(
1127 self: Arc<Server>,
1128 request: TypedEnvelope<proto::GetUsers>,
1129 response: Response<proto::GetUsers>,
1130 ) -> Result<()> {
1131 let user_ids = request
1132 .payload
1133 .user_ids
1134 .into_iter()
1135 .map(UserId::from_proto)
1136 .collect();
1137 let users = self
1138 .app_state
1139 .db
1140 .get_users_by_ids(user_ids)
1141 .await?
1142 .into_iter()
1143 .map(|user| proto::User {
1144 id: user.id.to_proto(),
1145 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1146 github_login: user.github_login,
1147 })
1148 .collect();
1149 response.send(proto::UsersResponse { users })?;
1150 Ok(())
1151 }
1152
1153 async fn fuzzy_search_users(
1154 self: Arc<Server>,
1155 request: TypedEnvelope<proto::FuzzySearchUsers>,
1156 response: Response<proto::FuzzySearchUsers>,
1157 ) -> Result<()> {
1158 let user_id = self
1159 .store()
1160 .await
1161 .user_id_for_connection(request.sender_id)?;
1162 let query = request.payload.query;
1163 let db = &self.app_state.db;
1164 let users = match query.len() {
1165 0 => vec![],
1166 1 | 2 => db
1167 .get_user_by_github_login(&query)
1168 .await?
1169 .into_iter()
1170 .collect(),
1171 _ => db.fuzzy_search_users(&query, 10).await?,
1172 };
1173 let users = users
1174 .into_iter()
1175 .filter(|user| user.id != user_id)
1176 .map(|user| proto::User {
1177 id: user.id.to_proto(),
1178 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1179 github_login: user.github_login,
1180 })
1181 .collect();
1182 response.send(proto::UsersResponse { users })?;
1183 Ok(())
1184 }
1185
1186 async fn request_contact(
1187 self: Arc<Server>,
1188 request: TypedEnvelope<proto::RequestContact>,
1189 response: Response<proto::RequestContact>,
1190 ) -> Result<()> {
1191 let requester_id = self
1192 .store()
1193 .await
1194 .user_id_for_connection(request.sender_id)?;
1195 let responder_id = UserId::from_proto(request.payload.responder_id);
1196 if requester_id == responder_id {
1197 return Err(anyhow!("cannot add yourself as a contact"))?;
1198 }
1199
1200 self.app_state
1201 .db
1202 .send_contact_request(requester_id, responder_id)
1203 .await?;
1204
1205 // Update outgoing contact requests of requester
1206 let mut update = proto::UpdateContacts::default();
1207 update.outgoing_requests.push(responder_id.to_proto());
1208 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1209 self.peer.send(connection_id, update.clone())?;
1210 }
1211
1212 // Update incoming contact requests of responder
1213 let mut update = proto::UpdateContacts::default();
1214 update
1215 .incoming_requests
1216 .push(proto::IncomingContactRequest {
1217 requester_id: requester_id.to_proto(),
1218 should_notify: true,
1219 });
1220 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1221 self.peer.send(connection_id, update.clone())?;
1222 }
1223
1224 response.send(proto::Ack {})?;
1225 Ok(())
1226 }
1227
1228 async fn respond_to_contact_request(
1229 self: Arc<Server>,
1230 request: TypedEnvelope<proto::RespondToContactRequest>,
1231 response: Response<proto::RespondToContactRequest>,
1232 ) -> Result<()> {
1233 let responder_id = self
1234 .store()
1235 .await
1236 .user_id_for_connection(request.sender_id)?;
1237 let requester_id = UserId::from_proto(request.payload.requester_id);
1238 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1239 self.app_state
1240 .db
1241 .dismiss_contact_notification(responder_id, requester_id)
1242 .await?;
1243 } else {
1244 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1245 self.app_state
1246 .db
1247 .respond_to_contact_request(responder_id, requester_id, accept)
1248 .await?;
1249
1250 let store = self.store().await;
1251 // Update responder with new contact
1252 let mut update = proto::UpdateContacts::default();
1253 if accept {
1254 update
1255 .contacts
1256 .push(store.contact_for_user(requester_id, false));
1257 }
1258 update
1259 .remove_incoming_requests
1260 .push(requester_id.to_proto());
1261 for connection_id in store.connection_ids_for_user(responder_id) {
1262 self.peer.send(connection_id, update.clone())?;
1263 }
1264
1265 // Update requester with new contact
1266 let mut update = proto::UpdateContacts::default();
1267 if accept {
1268 update
1269 .contacts
1270 .push(store.contact_for_user(responder_id, true));
1271 }
1272 update
1273 .remove_outgoing_requests
1274 .push(responder_id.to_proto());
1275 for connection_id in store.connection_ids_for_user(requester_id) {
1276 self.peer.send(connection_id, update.clone())?;
1277 }
1278 }
1279
1280 response.send(proto::Ack {})?;
1281 Ok(())
1282 }
1283
1284 async fn remove_contact(
1285 self: Arc<Server>,
1286 request: TypedEnvelope<proto::RemoveContact>,
1287 response: Response<proto::RemoveContact>,
1288 ) -> Result<()> {
1289 let requester_id = self
1290 .store()
1291 .await
1292 .user_id_for_connection(request.sender_id)?;
1293 let responder_id = UserId::from_proto(request.payload.user_id);
1294 self.app_state
1295 .db
1296 .remove_contact(requester_id, responder_id)
1297 .await?;
1298
1299 // Update outgoing contact requests of requester
1300 let mut update = proto::UpdateContacts::default();
1301 update
1302 .remove_outgoing_requests
1303 .push(responder_id.to_proto());
1304 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1305 self.peer.send(connection_id, update.clone())?;
1306 }
1307
1308 // Update incoming contact requests of responder
1309 let mut update = proto::UpdateContacts::default();
1310 update
1311 .remove_incoming_requests
1312 .push(requester_id.to_proto());
1313 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1314 self.peer.send(connection_id, update.clone())?;
1315 }
1316
1317 response.send(proto::Ack {})?;
1318 Ok(())
1319 }
1320
1321 async fn join_channel(
1322 self: Arc<Self>,
1323 request: TypedEnvelope<proto::JoinChannel>,
1324 response: Response<proto::JoinChannel>,
1325 ) -> Result<()> {
1326 let user_id = self
1327 .store()
1328 .await
1329 .user_id_for_connection(request.sender_id)?;
1330 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1331 if !self
1332 .app_state
1333 .db
1334 .can_user_access_channel(user_id, channel_id)
1335 .await?
1336 {
1337 Err(anyhow!("access denied"))?;
1338 }
1339
1340 self.store_mut()
1341 .await
1342 .join_channel(request.sender_id, channel_id);
1343 let messages = self
1344 .app_state
1345 .db
1346 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1347 .await?
1348 .into_iter()
1349 .map(|msg| proto::ChannelMessage {
1350 id: msg.id.to_proto(),
1351 body: msg.body,
1352 timestamp: msg.sent_at.unix_timestamp() as u64,
1353 sender_id: msg.sender_id.to_proto(),
1354 nonce: Some(msg.nonce.as_u128().into()),
1355 })
1356 .collect::<Vec<_>>();
1357 response.send(proto::JoinChannelResponse {
1358 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1359 messages,
1360 })?;
1361 Ok(())
1362 }
1363
1364 async fn leave_channel(
1365 self: Arc<Self>,
1366 request: TypedEnvelope<proto::LeaveChannel>,
1367 ) -> Result<()> {
1368 let user_id = self
1369 .store()
1370 .await
1371 .user_id_for_connection(request.sender_id)?;
1372 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1373 if !self
1374 .app_state
1375 .db
1376 .can_user_access_channel(user_id, channel_id)
1377 .await?
1378 {
1379 Err(anyhow!("access denied"))?;
1380 }
1381
1382 self.store_mut()
1383 .await
1384 .leave_channel(request.sender_id, channel_id);
1385
1386 Ok(())
1387 }
1388
1389 async fn send_channel_message(
1390 self: Arc<Self>,
1391 request: TypedEnvelope<proto::SendChannelMessage>,
1392 response: Response<proto::SendChannelMessage>,
1393 ) -> Result<()> {
1394 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1395 let user_id;
1396 let connection_ids;
1397 {
1398 let state = self.store().await;
1399 user_id = state.user_id_for_connection(request.sender_id)?;
1400 connection_ids = state.channel_connection_ids(channel_id)?;
1401 }
1402
1403 // Validate the message body.
1404 let body = request.payload.body.trim().to_string();
1405 if body.len() > MAX_MESSAGE_LEN {
1406 return Err(anyhow!("message is too long"))?;
1407 }
1408 if body.is_empty() {
1409 return Err(anyhow!("message can't be blank"))?;
1410 }
1411
1412 let timestamp = OffsetDateTime::now_utc();
1413 let nonce = request
1414 .payload
1415 .nonce
1416 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1417
1418 let message_id = self
1419 .app_state
1420 .db
1421 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1422 .await?
1423 .to_proto();
1424 let message = proto::ChannelMessage {
1425 sender_id: user_id.to_proto(),
1426 id: message_id,
1427 body,
1428 timestamp: timestamp.unix_timestamp() as u64,
1429 nonce: Some(nonce),
1430 };
1431 broadcast(request.sender_id, connection_ids, |conn_id| {
1432 self.peer.send(
1433 conn_id,
1434 proto::ChannelMessageSent {
1435 channel_id: channel_id.to_proto(),
1436 message: Some(message.clone()),
1437 },
1438 )
1439 });
1440 response.send(proto::SendChannelMessageResponse {
1441 message: Some(message),
1442 })?;
1443 Ok(())
1444 }
1445
1446 async fn get_channel_messages(
1447 self: Arc<Self>,
1448 request: TypedEnvelope<proto::GetChannelMessages>,
1449 response: Response<proto::GetChannelMessages>,
1450 ) -> Result<()> {
1451 let user_id = self
1452 .store()
1453 .await
1454 .user_id_for_connection(request.sender_id)?;
1455 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1456 if !self
1457 .app_state
1458 .db
1459 .can_user_access_channel(user_id, channel_id)
1460 .await?
1461 {
1462 Err(anyhow!("access denied"))?;
1463 }
1464
1465 let messages = self
1466 .app_state
1467 .db
1468 .get_channel_messages(
1469 channel_id,
1470 MESSAGE_COUNT_PER_PAGE,
1471 Some(MessageId::from_proto(request.payload.before_message_id)),
1472 )
1473 .await?
1474 .into_iter()
1475 .map(|msg| proto::ChannelMessage {
1476 id: msg.id.to_proto(),
1477 body: msg.body,
1478 timestamp: msg.sent_at.unix_timestamp() as u64,
1479 sender_id: msg.sender_id.to_proto(),
1480 nonce: Some(msg.nonce.as_u128().into()),
1481 })
1482 .collect::<Vec<_>>();
1483 response.send(proto::GetChannelMessagesResponse {
1484 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1485 messages,
1486 })?;
1487 Ok(())
1488 }
1489
1490 async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1491 #[cfg(test)]
1492 tokio::task::yield_now().await;
1493 let guard = self.store.read().await;
1494 #[cfg(test)]
1495 tokio::task::yield_now().await;
1496 StoreReadGuard {
1497 guard,
1498 _not_send: PhantomData,
1499 }
1500 }
1501
1502 async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1503 #[cfg(test)]
1504 tokio::task::yield_now().await;
1505 let guard = self.store.write().await;
1506 #[cfg(test)]
1507 tokio::task::yield_now().await;
1508 StoreWriteGuard {
1509 guard,
1510 _not_send: PhantomData,
1511 }
1512 }
1513
1514 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1515 ServerSnapshot {
1516 store: self.store.read().await,
1517 peer: &self.peer,
1518 }
1519 }
1520}
1521
1522impl<'a> Deref for StoreReadGuard<'a> {
1523 type Target = Store;
1524
1525 fn deref(&self) -> &Self::Target {
1526 &*self.guard
1527 }
1528}
1529
1530impl<'a> Deref for StoreWriteGuard<'a> {
1531 type Target = Store;
1532
1533 fn deref(&self) -> &Self::Target {
1534 &*self.guard
1535 }
1536}
1537
1538impl<'a> DerefMut for StoreWriteGuard<'a> {
1539 fn deref_mut(&mut self) -> &mut Self::Target {
1540 &mut *self.guard
1541 }
1542}
1543
1544impl<'a> Drop for StoreWriteGuard<'a> {
1545 fn drop(&mut self) {
1546 #[cfg(test)]
1547 self.check_invariants();
1548
1549 let metrics = self.metrics();
1550
1551 METRIC_CONNECTIONS.set(metrics.connections as _);
1552 METRIC_PROJECTS.set(metrics.registered_projects as _);
1553 METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1554
1555 tracing::info!(
1556 connections = metrics.connections,
1557 registered_projects = metrics.registered_projects,
1558 shared_projects = metrics.shared_projects,
1559 "metrics"
1560 );
1561 }
1562}
1563
1564impl Executor for RealExecutor {
1565 type Sleep = Sleep;
1566
1567 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1568 tokio::task::spawn(future);
1569 }
1570
1571 fn sleep(&self, duration: Duration) -> Self::Sleep {
1572 tokio::time::sleep(duration)
1573 }
1574}
1575
1576fn broadcast<F>(
1577 sender_id: ConnectionId,
1578 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1579 mut f: F,
1580) where
1581 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1582{
1583 for receiver_id in receiver_ids {
1584 if receiver_id != sender_id {
1585 f(receiver_id).trace_err();
1586 }
1587 }
1588}
1589
// Name of the request header through which clients report their RPC protocol
// version; it is decoded by `ProtocolVersion`.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1593
1594pub struct ProtocolVersion(u32);
1595
1596impl Header for ProtocolVersion {
1597 fn name() -> &'static HeaderName {
1598 &ZED_PROTOCOL_VERSION
1599 }
1600
1601 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1602 where
1603 Self: Sized,
1604 I: Iterator<Item = &'i axum::http::HeaderValue>,
1605 {
1606 let version = values
1607 .next()
1608 .ok_or_else(|| axum::headers::Error::invalid())?
1609 .to_str()
1610 .map_err(|_| axum::headers::Error::invalid())?
1611 .parse()
1612 .map_err(|_| axum::headers::Error::invalid())?;
1613 Ok(Self(version))
1614 }
1615
1616 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1617 values.extend([self.0.to_string().parse().unwrap()]);
1618 }
1619}
1620
1621pub fn routes(server: Arc<Server>) -> Router<Body> {
1622 Router::new()
1623 .route("/rpc", get(handle_websocket_request))
1624 .layer(
1625 ServiceBuilder::new()
1626 .layer(Extension(server.app_state.clone()))
1627 .layer(middleware::from_fn(auth::validate_header))
1628 .layer(Extension(server)),
1629 )
1630 .route("/metrics", get(handle_metrics))
1631}
1632
1633pub async fn handle_websocket_request(
1634 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1635 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1636 Extension(server): Extension<Arc<Server>>,
1637 Extension(user): Extension<User>,
1638 ws: WebSocketUpgrade,
1639) -> axum::response::Response {
1640 if protocol_version != rpc::PROTOCOL_VERSION {
1641 return (
1642 StatusCode::UPGRADE_REQUIRED,
1643 "client must be upgraded".to_string(),
1644 )
1645 .into_response();
1646 }
1647 let socket_address = socket_address.to_string();
1648 ws.on_upgrade(move |socket| {
1649 use util::ResultExt;
1650 let socket = socket
1651 .map_ok(to_tungstenite_message)
1652 .err_into()
1653 .with(|message| async move { Ok(to_axum_message(message)) });
1654 let connection = Connection::new(Box::pin(socket));
1655 async move {
1656 server
1657 .handle_connection(connection, socket_address, user, None, RealExecutor)
1658 .await
1659 .log_err();
1660 }
1661 })
1662}
1663
1664pub async fn handle_metrics() -> axum::response::Response {
1665 let encoder = prometheus::TextEncoder::new();
1666 let metric_families = prometheus::gather();
1667 match encoder.encode_to_string(&metric_families) {
1668 Ok(string) => (StatusCode::OK, string).into_response(),
1669 Err(error) => (
1670 StatusCode::INTERNAL_SERVER_ERROR,
1671 format!("failed to encode metrics {:?}", error),
1672 )
1673 .into_response(),
1674 }
1675}
1676
1677fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1678 match message {
1679 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1680 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1681 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1682 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1683 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1684 code: frame.code.into(),
1685 reason: frame.reason,
1686 })),
1687 }
1688}
1689
1690fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1691 match message {
1692 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1693 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1694 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1695 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1696 AxumMessage::Close(frame) => {
1697 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1698 code: frame.code.into(),
1699 reason: frame.reason,
1700 }))
1701 }
1702 }
1703}
1704
/// Extension for result-like values that reports failures through `tracing`
/// instead of propagating them.
pub trait ResultExt {
    /// The success type of the extended value.
    type Ok;

    /// Returns the success value as `Some`. On failure it returns `None`;
    /// the implementation for `Result` also logs the error with
    /// `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
1710
1711impl<T, E> ResultExt for Result<T, E>
1712where
1713 E: std::fmt::Debug,
1714{
1715 type Ok = T;
1716
1717 fn trace_err(self) -> Option<T> {
1718 match self {
1719 Ok(value) => Some(value),
1720 Err(error) => {
1721 tracing::error!("{:?}", error);
1722 None
1723 }
1724 }
1725 }
1726}