1mod store;
2
3use crate::{
4 auth,
5 db::{self, ChannelId, MessageId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::HashMap;
26use futures::{
27 channel::mpsc,
28 future::{self, BoxFuture},
29 FutureExt, SinkExt, StreamExt, TryStreamExt,
30};
31use lazy_static::lazy_static;
32use rpc::{
33 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
34 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
35};
36use serde::{Serialize, Serializer};
37use std::{
38 any::TypeId,
39 future::Future,
40 marker::PhantomData,
41 net::SocketAddr,
42 ops::{Deref, DerefMut},
43 rc::Rc,
44 sync::{
45 atomic::{AtomicBool, Ordering::SeqCst},
46 Arc,
47 },
48 time::Duration,
49};
50use time::OffsetDateTime;
51use tokio::{
52 sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
53 time::Sleep,
54};
55use tower::ServiceBuilder;
56use tracing::{info_span, instrument, Instrument};
57
58pub use store::{Store, Worktree};
59
60type MessageHandler =
61 Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
62
63struct Response<R> {
64 server: Arc<Server>,
65 receipt: Receipt<R>,
66 responded: Arc<AtomicBool>,
67}
68
69impl<R: RequestMessage> Response<R> {
70 fn send(self, payload: R::Response) -> Result<()> {
71 self.responded.store(true, SeqCst);
72 self.server.peer.respond(self.receipt, payload)?;
73 Ok(())
74 }
75
76 fn into_receipt(self) -> Receipt<R> {
77 self.responded.store(true, SeqCst);
78 self.receipt
79 }
80}
81
82pub struct Server {
83 peer: Arc<Peer>,
84 pub(crate) store: RwLock<Store>,
85 app_state: Arc<AppState>,
86 handlers: HashMap<TypeId, MessageHandler>,
87 notifications: Option<mpsc::UnboundedSender<()>>,
88}
89
/// Abstraction over the async runtime used by connection handling, so that it can be driven
/// by different executors.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Future + Send;

    /// Runs `future` to completion in the background without keeping a handle to it.
    fn spawn_detached<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static;

    /// Returns a future that completes after `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
95
/// Zero-sized executor handle meant to implement [`Executor`] for production use
/// (presumably on the real async runtime — confirm against its `impl Executor`).
#[derive(Clone)]
pub struct RealExecutor;
98
/// Page size for channel message listings (presumably used by the channel-message
/// handlers — confirm at the use site).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length (presumably measured in bytes — confirm at
/// the use site).
const MAX_MESSAGE_LEN: usize = 1024;
101
102struct StoreReadGuard<'a> {
103 guard: RwLockReadGuard<'a, Store>,
104 _not_send: PhantomData<Rc<()>>,
105}
106
107struct StoreWriteGuard<'a> {
108 guard: RwLockWriteGuard<'a, Store>,
109 _not_send: PhantomData<Rc<()>>,
110}
111
112#[derive(Serialize)]
113pub struct ServerSnapshot<'a> {
114 peer: &'a Peer,
115 #[serde(serialize_with = "serialize_deref")]
116 store: RwLockReadGuard<'a, Store>,
117}
118
119pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
120where
121 S: Serializer,
122 T: Deref<Target = U>,
123 U: Serialize,
124{
125 Serialize::serialize(value.deref(), serializer)
126}
127
128impl Server {
129 pub fn new(
130 app_state: Arc<AppState>,
131 notifications: Option<mpsc::UnboundedSender<()>>,
132 ) -> Arc<Self> {
133 let mut server = Self {
134 peer: Peer::new(),
135 app_state,
136 store: Default::default(),
137 handlers: Default::default(),
138 notifications,
139 };
140
141 server
142 .add_request_handler(Server::ping)
143 .add_request_handler(Server::register_project)
144 .add_request_handler(Server::unregister_project)
145 .add_request_handler(Server::join_project)
146 .add_message_handler(Server::leave_project)
147 .add_message_handler(Server::respond_to_join_project_request)
148 .add_message_handler(Server::update_project)
149 .add_request_handler(Server::update_worktree)
150 .add_message_handler(Server::start_language_server)
151 .add_message_handler(Server::update_language_server)
152 .add_message_handler(Server::update_diagnostic_summary)
153 .add_request_handler(Server::forward_project_request::<proto::GetHover>)
154 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
155 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
156 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
157 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
158 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
159 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
160 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
161 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
162 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
163 .add_request_handler(
164 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
165 )
166 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
167 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
168 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
169 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
170 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
171 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
172 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
173 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
174 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
175 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
176 .add_request_handler(Server::update_buffer)
177 .add_message_handler(Server::update_buffer_file)
178 .add_message_handler(Server::buffer_reloaded)
179 .add_message_handler(Server::buffer_saved)
180 .add_request_handler(Server::save_buffer)
181 .add_request_handler(Server::get_channels)
182 .add_request_handler(Server::get_users)
183 .add_request_handler(Server::fuzzy_search_users)
184 .add_request_handler(Server::request_contact)
185 .add_request_handler(Server::remove_contact)
186 .add_request_handler(Server::respond_to_contact_request)
187 .add_request_handler(Server::join_channel)
188 .add_message_handler(Server::leave_channel)
189 .add_request_handler(Server::send_channel_message)
190 .add_request_handler(Server::follow)
191 .add_message_handler(Server::unfollow)
192 .add_message_handler(Server::update_followers)
193 .add_request_handler(Server::get_channel_messages);
194
195 Arc::new(server)
196 }
197
198 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
199 where
200 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
201 Fut: 'static + Send + Future<Output = Result<()>>,
202 M: EnvelopedMessage,
203 {
204 let prev_handler = self.handlers.insert(
205 TypeId::of::<M>(),
206 Box::new(move |server, envelope| {
207 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
208 let span = info_span!(
209 "handle message",
210 payload_type = envelope.payload_type_name()
211 );
212 span.in_scope(|| {
213 tracing::info!(
214 payload = format!("{:?}", envelope.payload).as_str(),
215 "message payload"
216 );
217 });
218 let future = (handler)(server, *envelope);
219 async move {
220 if let Err(error) = future.await {
221 tracing::error!(%error, "error handling message");
222 }
223 }
224 .instrument(span)
225 .boxed()
226 }),
227 );
228 if prev_handler.is_some() {
229 panic!("registered a handler for the same message twice");
230 }
231 self
232 }
233
234 /// Handle a request while holding a lock to the store. This is useful when we're registering
235 /// a connection but we want to respond on the connection before anybody else can send on it.
236 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
237 where
238 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
239 Fut: Send + Future<Output = Result<()>>,
240 M: RequestMessage,
241 {
242 let handler = Arc::new(handler);
243 self.add_message_handler(move |server, envelope| {
244 let receipt = envelope.receipt();
245 let handler = handler.clone();
246 async move {
247 let responded = Arc::new(AtomicBool::default());
248 let response = Response {
249 server: server.clone(),
250 responded: responded.clone(),
251 receipt: envelope.receipt(),
252 };
253 match (handler)(server.clone(), envelope, response).await {
254 Ok(()) => {
255 if responded.load(std::sync::atomic::Ordering::SeqCst) {
256 Ok(())
257 } else {
258 Err(anyhow!("handler did not send a response"))?
259 }
260 }
261 Err(error) => {
262 server.peer.respond_with_error(
263 receipt,
264 proto::Error {
265 message: error.to_string(),
266 },
267 )?;
268 Err(error)
269 }
270 }
271 }
272 })
273 }
274
275 pub fn handle_connection<E: Executor>(
276 self: &Arc<Self>,
277 connection: Connection,
278 address: String,
279 user: User,
280 mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
281 executor: E,
282 ) -> impl Future<Output = Result<()>> {
283 let mut this = self.clone();
284 let user_id = user.id;
285 let login = user.github_login;
286 let span = info_span!("handle connection", %user_id, %login, %address);
287 async move {
288 let (connection_id, handle_io, mut incoming_rx) = this
289 .peer
290 .add_connection(connection, {
291 let executor = executor.clone();
292 move |duration| {
293 let timer = executor.sleep(duration);
294 async move {
295 timer.await;
296 }
297 }
298 })
299 .await;
300
301 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
302
303 if let Some(send_connection_id) = send_connection_id.as_mut() {
304 let _ = send_connection_id.send(connection_id).await;
305 }
306
307 if !user.connected_once {
308 this.peer.send(connection_id, proto::ShowContacts {})?;
309 this.app_state.db.set_user_connected_once(user_id, true).await?;
310 }
311
312 let (contacts, invite_code) = future::try_join(
313 this.app_state.db.get_contacts(user_id),
314 this.app_state.db.get_invite_code_for_user(user_id)
315 ).await?;
316
317 {
318 let mut store = this.store_mut().await;
319 store.add_connection(connection_id, user_id);
320 this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
321
322 if let Some((code, count)) = invite_code {
323 this.peer.send(connection_id, proto::UpdateInviteInfo {
324 url: format!("{}{}", this.app_state.invite_link_prefix, code),
325 count,
326 })?;
327 }
328 }
329 this.update_user_contacts(user_id).await?;
330
331 let handle_io = handle_io.fuse();
332 futures::pin_mut!(handle_io);
333 loop {
334 let next_message = incoming_rx.next().fuse();
335 futures::pin_mut!(next_message);
336 futures::select_biased! {
337 result = handle_io => {
338 if let Err(error) = result {
339 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
340 }
341 break;
342 }
343 message = next_message => {
344 if let Some(message) = message {
345 let type_name = message.payload_type_name();
346 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
347 async {
348 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
349 let notifications = this.notifications.clone();
350 let is_background = message.is_background();
351 let handle_message = (handler)(this.clone(), message);
352 let handle_message = async move {
353 handle_message.await;
354 if let Some(mut notifications) = notifications {
355 let _ = notifications.send(()).await;
356 }
357 };
358 if is_background {
359 executor.spawn_detached(handle_message);
360 } else {
361 handle_message.await;
362 }
363 } else {
364 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
365 }
366 }.instrument(span).await;
367 } else {
368 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
369 break;
370 }
371 }
372 }
373 }
374
375 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
376 if let Err(error) = this.sign_out(connection_id).await {
377 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
378 }
379
380 Ok(())
381 }.instrument(span)
382 }
383
384 #[instrument(skip(self), err)]
385 async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
386 self.peer.disconnect(connection_id);
387
388 let removed_user_id = {
389 let mut store = self.store_mut().await;
390 let removed_connection = store.remove_connection(connection_id)?;
391
392 for (project_id, project) in removed_connection.hosted_projects {
393 broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
394 self.peer
395 .send(conn_id, proto::UnregisterProject { project_id })
396 });
397
398 for (_, receipts) in project.join_requests {
399 for receipt in receipts {
400 self.peer.respond(
401 receipt,
402 proto::JoinProjectResponse {
403 variant: Some(proto::join_project_response::Variant::Decline(
404 proto::join_project_response::Decline {
405 reason: proto::join_project_response::decline::Reason::WentOffline as i32
406 },
407 )),
408 },
409 )?;
410 }
411 }
412 }
413
414 for project_id in removed_connection.guest_project_ids {
415 if let Some(project) = store.project(project_id).trace_err() {
416 broadcast(connection_id, project.connection_ids(), |conn_id| {
417 self.peer.send(
418 conn_id,
419 proto::RemoveProjectCollaborator {
420 project_id,
421 peer_id: connection_id.0,
422 },
423 )
424 });
425 if project.guests.is_empty() {
426 self.peer
427 .send(
428 project.host_connection_id,
429 proto::ProjectUnshared { project_id },
430 )
431 .trace_err();
432 }
433 }
434 }
435
436 removed_connection.user_id
437 };
438
439 self.update_user_contacts(removed_user_id).await?;
440
441 Ok(())
442 }
443
444 pub async fn invite_code_redeemed(
445 self: &Arc<Self>,
446 code: &str,
447 invitee_id: UserId,
448 ) -> Result<()> {
449 let user = self.app_state.db.get_user_for_invite_code(code).await?;
450 let store = self.store().await;
451 let invitee_contact = store.contact_for_user(invitee_id, true);
452 for connection_id in store.connection_ids_for_user(user.id) {
453 self.peer.send(
454 connection_id,
455 proto::UpdateContacts {
456 contacts: vec![invitee_contact.clone()],
457 ..Default::default()
458 },
459 )?;
460 self.peer.send(
461 connection_id,
462 proto::UpdateInviteInfo {
463 url: format!("{}{}", self.app_state.invite_link_prefix, code),
464 count: user.invite_count as u32,
465 },
466 )?;
467 }
468 Ok(())
469 }
470
471 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
472 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
473 if let Some(invite_code) = &user.invite_code {
474 let store = self.store().await;
475 for connection_id in store.connection_ids_for_user(user_id) {
476 self.peer.send(
477 connection_id,
478 proto::UpdateInviteInfo {
479 url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
480 count: user.invite_count as u32,
481 },
482 )?;
483 }
484 }
485 }
486 Ok(())
487 }
488
489 async fn ping(
490 self: Arc<Server>,
491 _: TypedEnvelope<proto::Ping>,
492 response: Response<proto::Ping>,
493 ) -> Result<()> {
494 response.send(proto::Ack {})?;
495 Ok(())
496 }
497
498 async fn register_project(
499 self: Arc<Server>,
500 request: TypedEnvelope<proto::RegisterProject>,
501 response: Response<proto::RegisterProject>,
502 ) -> Result<()> {
503 let project_id;
504 {
505 let mut state = self.store_mut().await;
506 let user_id = state.user_id_for_connection(request.sender_id)?;
507 project_id = state.register_project(request.sender_id, user_id);
508 };
509
510 response.send(proto::RegisterProjectResponse { project_id })?;
511
512 Ok(())
513 }
514
515 async fn unregister_project(
516 self: Arc<Server>,
517 request: TypedEnvelope<proto::UnregisterProject>,
518 response: Response<proto::UnregisterProject>,
519 ) -> Result<()> {
520 let (user_id, project) = {
521 let mut state = self.store_mut().await;
522 let project =
523 state.unregister_project(request.payload.project_id, request.sender_id)?;
524 (state.user_id_for_connection(request.sender_id)?, project)
525 };
526
527 broadcast(
528 request.sender_id,
529 project.guests.keys().copied(),
530 |conn_id| {
531 self.peer.send(
532 conn_id,
533 proto::UnregisterProject {
534 project_id: request.payload.project_id,
535 },
536 )
537 },
538 );
539 for (_, receipts) in project.join_requests {
540 for receipt in receipts {
541 self.peer.respond(
542 receipt,
543 proto::JoinProjectResponse {
544 variant: Some(proto::join_project_response::Variant::Decline(
545 proto::join_project_response::Decline {
546 reason: proto::join_project_response::decline::Reason::Closed
547 as i32,
548 },
549 )),
550 },
551 )?;
552 }
553 }
554
555 // Send out the `UpdateContacts` message before responding to the unregister
556 // request. This way, when the project's host can keep track of the project's
557 // remote id until after they've received the `UpdateContacts` message for
558 // themself.
559 self.update_user_contacts(user_id).await?;
560 response.send(proto::Ack {})?;
561
562 Ok(())
563 }
564
565 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
566 let contacts = self.app_state.db.get_contacts(user_id).await?;
567 let store = self.store().await;
568 let updated_contact = store.contact_for_user(user_id, false);
569 for contact in contacts {
570 if let db::Contact::Accepted {
571 user_id: contact_user_id,
572 ..
573 } = contact
574 {
575 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
576 self.peer
577 .send(
578 contact_conn_id,
579 proto::UpdateContacts {
580 contacts: vec![updated_contact.clone()],
581 remove_contacts: Default::default(),
582 incoming_requests: Default::default(),
583 remove_incoming_requests: Default::default(),
584 outgoing_requests: Default::default(),
585 remove_outgoing_requests: Default::default(),
586 },
587 )
588 .trace_err();
589 }
590 }
591 }
592 Ok(())
593 }
594
595 async fn join_project(
596 self: Arc<Server>,
597 request: TypedEnvelope<proto::JoinProject>,
598 response: Response<proto::JoinProject>,
599 ) -> Result<()> {
600 let project_id = request.payload.project_id;
601
602 let host_user_id;
603 let guest_user_id;
604 let host_connection_id;
605 {
606 let state = self.store().await;
607 let project = state.project(project_id)?;
608 host_user_id = project.host_user_id;
609 host_connection_id = project.host_connection_id;
610 guest_user_id = state.user_id_for_connection(request.sender_id)?;
611 };
612
613 tracing::info!(project_id, %host_user_id, %host_connection_id, "join project");
614 let has_contact = self
615 .app_state
616 .db
617 .has_contact(guest_user_id, host_user_id)
618 .await?;
619 if !has_contact {
620 return Err(anyhow!("no such project"))?;
621 }
622
623 self.store_mut().await.request_join_project(
624 guest_user_id,
625 project_id,
626 response.into_receipt(),
627 )?;
628 self.peer.send(
629 host_connection_id,
630 proto::RequestJoinProject {
631 project_id,
632 requester_id: guest_user_id.to_proto(),
633 },
634 )?;
635 Ok(())
636 }
637
    /// Handles the host's answer to a pending request to join one of its projects.
    ///
    /// Only the project's host connection may answer; otherwise "no such connection" is
    /// returned.
    ///
    /// When `allow` is false, every receipt returned by the store's
    /// `deny_join_project_request` gets a `Declined` response, and the function returns
    /// without refreshing contacts.
    ///
    /// When `allow` is true, the store accepts the request. Every other connection in the
    /// project is then sent an `AddProjectCollaborator` for each accepted connection. Each
    /// accepted receipt is answered with an `Accept` carrying the project's worktrees,
    /// collaborators, and language servers. Finally, the host user's contacts are refreshed.
    ///
    /// This is registered as a message handler, so the host itself receives no reply.
    async fn respond_to_join_project_request(
        self: Arc<Server>,
        request: TypedEnvelope<proto::RespondToJoinProjectRequest>,
    ) -> Result<()> {
        let host_user_id;

        {
            let mut state = self.store_mut().await;
            let project_id = request.payload.project_id;
            let project = state.project(project_id)?;
            // Only the host may answer join requests for its project.
            if project.host_connection_id != request.sender_id {
                Err(anyhow!("no such connection"))?;
            }

            host_user_id = project.host_user_id;
            let guest_user_id = UserId::from_proto(request.payload.requester_id);

            if !request.payload.allow {
                let receipts = state
                    .deny_join_project_request(request.sender_id, guest_user_id, project_id)
                    .ok_or_else(|| anyhow!("no such request"))?;
                for receipt in receipts {
                    self.peer.respond(
                        receipt,
                        proto::JoinProjectResponse {
                            variant: Some(proto::join_project_response::Variant::Decline(
                                proto::join_project_response::Decline {
                                    reason: proto::join_project_response::decline::Reason::Declined
                                        as i32,
                                },
                            )),
                        },
                    )?;
                }
                // NOTE(review): this early return skips `update_user_contacts` below;
                // presumably intentional since project membership did not change — confirm.
                return Ok(());
            }

            let (receipts_with_replica_ids, project) = state
                .accept_join_project_request(request.sender_id, guest_user_id, project_id)
                .ok_or_else(|| anyhow!("no such request"))?;

            // The host is always reported as replica 0. Note the capacity hint counts guests
            // only, so pushing the host as well may trigger one reallocation.
            let peer_count = project.guests.len();
            let mut collaborators = Vec::with_capacity(peer_count);
            collaborators.push(proto::Collaborator {
                peer_id: project.host_connection_id.0,
                replica_id: 0,
                user_id: project.host_user_id.to_proto(),
            });
            // NOTE(review): `worktree` is looked up in the same map being iterated, so it
            // always exists and is the same entry as `shared_worktree`; the `filter_map`
            // never filters anything out.
            let worktrees = project
                .worktrees
                .iter()
                .filter_map(|(id, shared_worktree)| {
                    let worktree = project.worktrees.get(&id)?;
                    Some(proto::Worktree {
                        id: *id,
                        root_name: worktree.root_name.clone(),
                        entries: shared_worktree.entries.values().cloned().collect(),
                        diagnostic_summaries: shared_worktree
                            .diagnostic_summaries
                            .values()
                            .cloned()
                            .collect(),
                        visible: worktree.visible,
                        scan_id: shared_worktree.scan_id,
                    })
                })
                .collect::<Vec<_>>();

            // Add all guests other than the connections being accepted now as collaborators
            for (peer_conn_id, (peer_replica_id, peer_user_id)) in &project.guests {
                if receipts_with_replica_ids
                    .iter()
                    .all(|(receipt, _)| receipt.sender_id != *peer_conn_id)
                {
                    collaborators.push(proto::Collaborator {
                        peer_id: peer_conn_id.0,
                        replica_id: *peer_replica_id as u32,
                        user_id: peer_user_id.to_proto(),
                    });
                }
            }

            // Tell every other connection in the project about each newly accepted
            // connection. This includes sibling connections accepted in the same batch.
            // NOTE(review): a failed send here returns early via `?` before the receipts
            // below are answered. The accepted receipts were presumably already taken out of
            // the store, which would leave those requesters without a reply — confirm.
            for conn_id in project.connection_ids() {
                for (receipt, replica_id) in &receipts_with_replica_ids {
                    if conn_id != receipt.sender_id {
                        self.peer.send(
                            conn_id,
                            proto::AddProjectCollaborator {
                                project_id,
                                collaborator: Some(proto::Collaborator {
                                    peer_id: receipt.sender_id.0,
                                    replica_id: *replica_id as u32,
                                    user_id: guest_user_id.to_proto(),
                                }),
                            },
                        )?;
                    }
                }
            }

            for (receipt, replica_id) in receipts_with_replica_ids {
                self.peer.respond(
                    receipt,
                    proto::JoinProjectResponse {
                        variant: Some(proto::join_project_response::Variant::Accept(
                            proto::join_project_response::Accept {
                                worktrees: worktrees.clone(),
                                replica_id: replica_id as u32,
                                collaborators: collaborators.clone(),
                                language_servers: project.language_servers.clone(),
                            },
                        )),
                    },
                )?;
            }
        }

        self.update_user_contacts(host_user_id).await?;
        Ok(())
    }
758
759 async fn leave_project(
760 self: Arc<Server>,
761 request: TypedEnvelope<proto::LeaveProject>,
762 ) -> Result<()> {
763 let sender_id = request.sender_id;
764 let project_id = request.payload.project_id;
765 let project;
766 {
767 let mut store = self.store_mut().await;
768 project = store.leave_project(sender_id, project_id)?;
769 tracing::info!(
770 project_id,
771 host_user_id = %project.host_user_id,
772 host_connection_id = %project.host_connection_id,
773 "leave project"
774 );
775
776 if project.remove_collaborator {
777 broadcast(sender_id, project.connection_ids, |conn_id| {
778 self.peer.send(
779 conn_id,
780 proto::RemoveProjectCollaborator {
781 project_id,
782 peer_id: sender_id.0,
783 },
784 )
785 });
786 }
787
788 if let Some(requester_id) = project.cancel_request {
789 self.peer.send(
790 project.host_connection_id,
791 proto::JoinProjectRequestCancelled {
792 project_id,
793 requester_id: requester_id.to_proto(),
794 },
795 )?;
796 }
797
798 if project.unshare {
799 self.peer.send(
800 project.host_connection_id,
801 proto::ProjectUnshared { project_id },
802 )?;
803 }
804 }
805 self.update_user_contacts(project.host_user_id).await?;
806 Ok(())
807 }
808
809 async fn update_project(
810 self: Arc<Server>,
811 request: TypedEnvelope<proto::UpdateProject>,
812 ) -> Result<()> {
813 let user_id;
814 {
815 let mut state = self.store_mut().await;
816 user_id = state.user_id_for_connection(request.sender_id)?;
817 let guest_connection_ids = state
818 .read_project(request.payload.project_id, request.sender_id)?
819 .guest_connection_ids();
820 state.update_project(
821 request.payload.project_id,
822 &request.payload.worktrees,
823 request.sender_id,
824 )?;
825 broadcast(request.sender_id, guest_connection_ids, |connection_id| {
826 self.peer
827 .forward_send(request.sender_id, connection_id, request.payload.clone())
828 });
829 };
830 self.update_user_contacts(user_id).await?;
831 Ok(())
832 }
833
834 async fn update_worktree(
835 self: Arc<Server>,
836 request: TypedEnvelope<proto::UpdateWorktree>,
837 response: Response<proto::UpdateWorktree>,
838 ) -> Result<()> {
839 let (connection_ids, metadata_changed) = {
840 let mut store = self.store_mut().await;
841 let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
842 request.sender_id,
843 request.payload.project_id,
844 request.payload.worktree_id,
845 &request.payload.root_name,
846 &request.payload.removed_entries,
847 &request.payload.updated_entries,
848 request.payload.scan_id,
849 )?;
850 for (extension, count) in extension_counts {
851 tracing::info!(
852 project_id = request.payload.project_id,
853 worktree_id = request.payload.worktree_id,
854 ?extension,
855 %count,
856 "worktree updated"
857 );
858 }
859 (connection_ids, metadata_changed)
860 };
861
862 broadcast(request.sender_id, connection_ids, |connection_id| {
863 self.peer
864 .forward_send(request.sender_id, connection_id, request.payload.clone())
865 });
866 if metadata_changed {
867 let user_id = self
868 .store()
869 .await
870 .user_id_for_connection(request.sender_id)?;
871 self.update_user_contacts(user_id).await?;
872 }
873 response.send(proto::Ack {})?;
874 Ok(())
875 }
876
877 async fn update_diagnostic_summary(
878 self: Arc<Server>,
879 request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
880 ) -> Result<()> {
881 let summary = request
882 .payload
883 .summary
884 .clone()
885 .ok_or_else(|| anyhow!("invalid summary"))?;
886 let receiver_ids = self.store_mut().await.update_diagnostic_summary(
887 request.payload.project_id,
888 request.payload.worktree_id,
889 request.sender_id,
890 summary,
891 )?;
892
893 broadcast(request.sender_id, receiver_ids, |connection_id| {
894 self.peer
895 .forward_send(request.sender_id, connection_id, request.payload.clone())
896 });
897 Ok(())
898 }
899
900 async fn start_language_server(
901 self: Arc<Server>,
902 request: TypedEnvelope<proto::StartLanguageServer>,
903 ) -> Result<()> {
904 let receiver_ids = self.store_mut().await.start_language_server(
905 request.payload.project_id,
906 request.sender_id,
907 request
908 .payload
909 .server
910 .clone()
911 .ok_or_else(|| anyhow!("invalid language server"))?,
912 )?;
913 broadcast(request.sender_id, receiver_ids, |connection_id| {
914 self.peer
915 .forward_send(request.sender_id, connection_id, request.payload.clone())
916 });
917 Ok(())
918 }
919
920 async fn update_language_server(
921 self: Arc<Server>,
922 request: TypedEnvelope<proto::UpdateLanguageServer>,
923 ) -> Result<()> {
924 let receiver_ids = self
925 .store()
926 .await
927 .project_connection_ids(request.payload.project_id, request.sender_id)?;
928 broadcast(request.sender_id, receiver_ids, |connection_id| {
929 self.peer
930 .forward_send(request.sender_id, connection_id, request.payload.clone())
931 });
932 Ok(())
933 }
934
935 async fn forward_project_request<T>(
936 self: Arc<Server>,
937 request: TypedEnvelope<T>,
938 response: Response<T>,
939 ) -> Result<()>
940 where
941 T: EntityMessage + RequestMessage,
942 {
943 let host_connection_id = self
944 .store()
945 .await
946 .read_project(request.payload.remote_entity_id(), request.sender_id)?
947 .host_connection_id;
948
949 response.send(
950 self.peer
951 .forward_request(request.sender_id, host_connection_id, request.payload)
952 .await?,
953 )?;
954 Ok(())
955 }
956
957 async fn save_buffer(
958 self: Arc<Server>,
959 request: TypedEnvelope<proto::SaveBuffer>,
960 response: Response<proto::SaveBuffer>,
961 ) -> Result<()> {
962 let host = self
963 .store()
964 .await
965 .read_project(request.payload.project_id, request.sender_id)?
966 .host_connection_id;
967 let response_payload = self
968 .peer
969 .forward_request(request.sender_id, host, request.payload.clone())
970 .await?;
971
972 let mut guests = self
973 .store()
974 .await
975 .read_project(request.payload.project_id, request.sender_id)?
976 .connection_ids();
977 guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
978 broadcast(host, guests, |conn_id| {
979 self.peer
980 .forward_send(host, conn_id, response_payload.clone())
981 });
982 response.send(response_payload)?;
983 Ok(())
984 }
985
986 async fn update_buffer(
987 self: Arc<Server>,
988 request: TypedEnvelope<proto::UpdateBuffer>,
989 response: Response<proto::UpdateBuffer>,
990 ) -> Result<()> {
991 let receiver_ids = self
992 .store()
993 .await
994 .project_connection_ids(request.payload.project_id, request.sender_id)?;
995 broadcast(request.sender_id, receiver_ids, |connection_id| {
996 self.peer
997 .forward_send(request.sender_id, connection_id, request.payload.clone())
998 });
999 response.send(proto::Ack {})?;
1000 Ok(())
1001 }
1002
1003 async fn update_buffer_file(
1004 self: Arc<Server>,
1005 request: TypedEnvelope<proto::UpdateBufferFile>,
1006 ) -> Result<()> {
1007 let receiver_ids = self
1008 .store()
1009 .await
1010 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1011 broadcast(request.sender_id, receiver_ids, |connection_id| {
1012 self.peer
1013 .forward_send(request.sender_id, connection_id, request.payload.clone())
1014 });
1015 Ok(())
1016 }
1017
1018 async fn buffer_reloaded(
1019 self: Arc<Server>,
1020 request: TypedEnvelope<proto::BufferReloaded>,
1021 ) -> Result<()> {
1022 let receiver_ids = self
1023 .store()
1024 .await
1025 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1026 broadcast(request.sender_id, receiver_ids, |connection_id| {
1027 self.peer
1028 .forward_send(request.sender_id, connection_id, request.payload.clone())
1029 });
1030 Ok(())
1031 }
1032
1033 async fn buffer_saved(
1034 self: Arc<Server>,
1035 request: TypedEnvelope<proto::BufferSaved>,
1036 ) -> Result<()> {
1037 let receiver_ids = self
1038 .store()
1039 .await
1040 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1041 broadcast(request.sender_id, receiver_ids, |connection_id| {
1042 self.peer
1043 .forward_send(request.sender_id, connection_id, request.payload.clone())
1044 });
1045 Ok(())
1046 }
1047
1048 async fn follow(
1049 self: Arc<Self>,
1050 request: TypedEnvelope<proto::Follow>,
1051 response: Response<proto::Follow>,
1052 ) -> Result<()> {
1053 let leader_id = ConnectionId(request.payload.leader_id);
1054 let follower_id = request.sender_id;
1055 if !self
1056 .store()
1057 .await
1058 .project_connection_ids(request.payload.project_id, follower_id)?
1059 .contains(&leader_id)
1060 {
1061 Err(anyhow!("no such peer"))?;
1062 }
1063 let mut response_payload = self
1064 .peer
1065 .forward_request(request.sender_id, leader_id, request.payload)
1066 .await?;
1067 response_payload
1068 .views
1069 .retain(|view| view.leader_id != Some(follower_id.0));
1070 response.send(response_payload)?;
1071 Ok(())
1072 }
1073
1074 async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
1075 let leader_id = ConnectionId(request.payload.leader_id);
1076 if !self
1077 .store()
1078 .await
1079 .project_connection_ids(request.payload.project_id, request.sender_id)?
1080 .contains(&leader_id)
1081 {
1082 Err(anyhow!("no such peer"))?;
1083 }
1084 self.peer
1085 .forward_send(request.sender_id, leader_id, request.payload)?;
1086 Ok(())
1087 }
1088
1089 async fn update_followers(
1090 self: Arc<Self>,
1091 request: TypedEnvelope<proto::UpdateFollowers>,
1092 ) -> Result<()> {
1093 let connection_ids = self
1094 .store()
1095 .await
1096 .project_connection_ids(request.payload.project_id, request.sender_id)?;
1097 let leader_id = request
1098 .payload
1099 .variant
1100 .as_ref()
1101 .and_then(|variant| match variant {
1102 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1103 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1104 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1105 });
1106 for follower_id in &request.payload.follower_ids {
1107 let follower_id = ConnectionId(*follower_id);
1108 if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1109 self.peer
1110 .forward_send(request.sender_id, follower_id, request.payload.clone())?;
1111 }
1112 }
1113 Ok(())
1114 }
1115
1116 async fn get_channels(
1117 self: Arc<Server>,
1118 request: TypedEnvelope<proto::GetChannels>,
1119 response: Response<proto::GetChannels>,
1120 ) -> Result<()> {
1121 let user_id = self
1122 .store()
1123 .await
1124 .user_id_for_connection(request.sender_id)?;
1125 let channels = self.app_state.db.get_accessible_channels(user_id).await?;
1126 response.send(proto::GetChannelsResponse {
1127 channels: channels
1128 .into_iter()
1129 .map(|chan| proto::Channel {
1130 id: chan.id.to_proto(),
1131 name: chan.name,
1132 })
1133 .collect(),
1134 })?;
1135 Ok(())
1136 }
1137
1138 async fn get_users(
1139 self: Arc<Server>,
1140 request: TypedEnvelope<proto::GetUsers>,
1141 response: Response<proto::GetUsers>,
1142 ) -> Result<()> {
1143 let user_ids = request
1144 .payload
1145 .user_ids
1146 .into_iter()
1147 .map(UserId::from_proto)
1148 .collect();
1149 let users = self
1150 .app_state
1151 .db
1152 .get_users_by_ids(user_ids)
1153 .await?
1154 .into_iter()
1155 .map(|user| proto::User {
1156 id: user.id.to_proto(),
1157 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1158 github_login: user.github_login,
1159 })
1160 .collect();
1161 response.send(proto::UsersResponse { users })?;
1162 Ok(())
1163 }
1164
1165 async fn fuzzy_search_users(
1166 self: Arc<Server>,
1167 request: TypedEnvelope<proto::FuzzySearchUsers>,
1168 response: Response<proto::FuzzySearchUsers>,
1169 ) -> Result<()> {
1170 let user_id = self
1171 .store()
1172 .await
1173 .user_id_for_connection(request.sender_id)?;
1174 let query = request.payload.query;
1175 let db = &self.app_state.db;
1176 let users = match query.len() {
1177 0 => vec![],
1178 1 | 2 => db
1179 .get_user_by_github_login(&query)
1180 .await?
1181 .into_iter()
1182 .collect(),
1183 _ => db.fuzzy_search_users(&query, 10).await?,
1184 };
1185 let users = users
1186 .into_iter()
1187 .filter(|user| user.id != user_id)
1188 .map(|user| proto::User {
1189 id: user.id.to_proto(),
1190 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1191 github_login: user.github_login,
1192 })
1193 .collect();
1194 response.send(proto::UsersResponse { users })?;
1195 Ok(())
1196 }
1197
1198 async fn request_contact(
1199 self: Arc<Server>,
1200 request: TypedEnvelope<proto::RequestContact>,
1201 response: Response<proto::RequestContact>,
1202 ) -> Result<()> {
1203 let requester_id = self
1204 .store()
1205 .await
1206 .user_id_for_connection(request.sender_id)?;
1207 let responder_id = UserId::from_proto(request.payload.responder_id);
1208 if requester_id == responder_id {
1209 return Err(anyhow!("cannot add yourself as a contact"))?;
1210 }
1211
1212 self.app_state
1213 .db
1214 .send_contact_request(requester_id, responder_id)
1215 .await?;
1216
1217 // Update outgoing contact requests of requester
1218 let mut update = proto::UpdateContacts::default();
1219 update.outgoing_requests.push(responder_id.to_proto());
1220 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1221 self.peer.send(connection_id, update.clone())?;
1222 }
1223
1224 // Update incoming contact requests of responder
1225 let mut update = proto::UpdateContacts::default();
1226 update
1227 .incoming_requests
1228 .push(proto::IncomingContactRequest {
1229 requester_id: requester_id.to_proto(),
1230 should_notify: true,
1231 });
1232 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1233 self.peer.send(connection_id, update.clone())?;
1234 }
1235
1236 response.send(proto::Ack {})?;
1237 Ok(())
1238 }
1239
1240 async fn respond_to_contact_request(
1241 self: Arc<Server>,
1242 request: TypedEnvelope<proto::RespondToContactRequest>,
1243 response: Response<proto::RespondToContactRequest>,
1244 ) -> Result<()> {
1245 let responder_id = self
1246 .store()
1247 .await
1248 .user_id_for_connection(request.sender_id)?;
1249 let requester_id = UserId::from_proto(request.payload.requester_id);
1250 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1251 self.app_state
1252 .db
1253 .dismiss_contact_notification(responder_id, requester_id)
1254 .await?;
1255 } else {
1256 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1257 self.app_state
1258 .db
1259 .respond_to_contact_request(responder_id, requester_id, accept)
1260 .await?;
1261
1262 let store = self.store().await;
1263 // Update responder with new contact
1264 let mut update = proto::UpdateContacts::default();
1265 if accept {
1266 update
1267 .contacts
1268 .push(store.contact_for_user(requester_id, false));
1269 }
1270 update
1271 .remove_incoming_requests
1272 .push(requester_id.to_proto());
1273 for connection_id in store.connection_ids_for_user(responder_id) {
1274 self.peer.send(connection_id, update.clone())?;
1275 }
1276
1277 // Update requester with new contact
1278 let mut update = proto::UpdateContacts::default();
1279 if accept {
1280 update
1281 .contacts
1282 .push(store.contact_for_user(responder_id, true));
1283 }
1284 update
1285 .remove_outgoing_requests
1286 .push(responder_id.to_proto());
1287 for connection_id in store.connection_ids_for_user(requester_id) {
1288 self.peer.send(connection_id, update.clone())?;
1289 }
1290 }
1291
1292 response.send(proto::Ack {})?;
1293 Ok(())
1294 }
1295
1296 async fn remove_contact(
1297 self: Arc<Server>,
1298 request: TypedEnvelope<proto::RemoveContact>,
1299 response: Response<proto::RemoveContact>,
1300 ) -> Result<()> {
1301 let requester_id = self
1302 .store()
1303 .await
1304 .user_id_for_connection(request.sender_id)?;
1305 let responder_id = UserId::from_proto(request.payload.user_id);
1306 self.app_state
1307 .db
1308 .remove_contact(requester_id, responder_id)
1309 .await?;
1310
1311 // Update outgoing contact requests of requester
1312 let mut update = proto::UpdateContacts::default();
1313 update
1314 .remove_outgoing_requests
1315 .push(responder_id.to_proto());
1316 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1317 self.peer.send(connection_id, update.clone())?;
1318 }
1319
1320 // Update incoming contact requests of responder
1321 let mut update = proto::UpdateContacts::default();
1322 update
1323 .remove_incoming_requests
1324 .push(requester_id.to_proto());
1325 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1326 self.peer.send(connection_id, update.clone())?;
1327 }
1328
1329 response.send(proto::Ack {})?;
1330 Ok(())
1331 }
1332
1333 async fn join_channel(
1334 self: Arc<Self>,
1335 request: TypedEnvelope<proto::JoinChannel>,
1336 response: Response<proto::JoinChannel>,
1337 ) -> Result<()> {
1338 let user_id = self
1339 .store()
1340 .await
1341 .user_id_for_connection(request.sender_id)?;
1342 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1343 if !self
1344 .app_state
1345 .db
1346 .can_user_access_channel(user_id, channel_id)
1347 .await?
1348 {
1349 Err(anyhow!("access denied"))?;
1350 }
1351
1352 self.store_mut()
1353 .await
1354 .join_channel(request.sender_id, channel_id);
1355 let messages = self
1356 .app_state
1357 .db
1358 .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
1359 .await?
1360 .into_iter()
1361 .map(|msg| proto::ChannelMessage {
1362 id: msg.id.to_proto(),
1363 body: msg.body,
1364 timestamp: msg.sent_at.unix_timestamp() as u64,
1365 sender_id: msg.sender_id.to_proto(),
1366 nonce: Some(msg.nonce.as_u128().into()),
1367 })
1368 .collect::<Vec<_>>();
1369 response.send(proto::JoinChannelResponse {
1370 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1371 messages,
1372 })?;
1373 Ok(())
1374 }
1375
1376 async fn leave_channel(
1377 self: Arc<Self>,
1378 request: TypedEnvelope<proto::LeaveChannel>,
1379 ) -> Result<()> {
1380 let user_id = self
1381 .store()
1382 .await
1383 .user_id_for_connection(request.sender_id)?;
1384 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1385 if !self
1386 .app_state
1387 .db
1388 .can_user_access_channel(user_id, channel_id)
1389 .await?
1390 {
1391 Err(anyhow!("access denied"))?;
1392 }
1393
1394 self.store_mut()
1395 .await
1396 .leave_channel(request.sender_id, channel_id);
1397
1398 Ok(())
1399 }
1400
1401 async fn send_channel_message(
1402 self: Arc<Self>,
1403 request: TypedEnvelope<proto::SendChannelMessage>,
1404 response: Response<proto::SendChannelMessage>,
1405 ) -> Result<()> {
1406 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1407 let user_id;
1408 let connection_ids;
1409 {
1410 let state = self.store().await;
1411 user_id = state.user_id_for_connection(request.sender_id)?;
1412 connection_ids = state.channel_connection_ids(channel_id)?;
1413 }
1414
1415 // Validate the message body.
1416 let body = request.payload.body.trim().to_string();
1417 if body.len() > MAX_MESSAGE_LEN {
1418 return Err(anyhow!("message is too long"))?;
1419 }
1420 if body.is_empty() {
1421 return Err(anyhow!("message can't be blank"))?;
1422 }
1423
1424 let timestamp = OffsetDateTime::now_utc();
1425 let nonce = request
1426 .payload
1427 .nonce
1428 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
1429
1430 let message_id = self
1431 .app_state
1432 .db
1433 .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
1434 .await?
1435 .to_proto();
1436 let message = proto::ChannelMessage {
1437 sender_id: user_id.to_proto(),
1438 id: message_id,
1439 body,
1440 timestamp: timestamp.unix_timestamp() as u64,
1441 nonce: Some(nonce),
1442 };
1443 broadcast(request.sender_id, connection_ids, |conn_id| {
1444 self.peer.send(
1445 conn_id,
1446 proto::ChannelMessageSent {
1447 channel_id: channel_id.to_proto(),
1448 message: Some(message.clone()),
1449 },
1450 )
1451 });
1452 response.send(proto::SendChannelMessageResponse {
1453 message: Some(message),
1454 })?;
1455 Ok(())
1456 }
1457
1458 async fn get_channel_messages(
1459 self: Arc<Self>,
1460 request: TypedEnvelope<proto::GetChannelMessages>,
1461 response: Response<proto::GetChannelMessages>,
1462 ) -> Result<()> {
1463 let user_id = self
1464 .store()
1465 .await
1466 .user_id_for_connection(request.sender_id)?;
1467 let channel_id = ChannelId::from_proto(request.payload.channel_id);
1468 if !self
1469 .app_state
1470 .db
1471 .can_user_access_channel(user_id, channel_id)
1472 .await?
1473 {
1474 Err(anyhow!("access denied"))?;
1475 }
1476
1477 let messages = self
1478 .app_state
1479 .db
1480 .get_channel_messages(
1481 channel_id,
1482 MESSAGE_COUNT_PER_PAGE,
1483 Some(MessageId::from_proto(request.payload.before_message_id)),
1484 )
1485 .await?
1486 .into_iter()
1487 .map(|msg| proto::ChannelMessage {
1488 id: msg.id.to_proto(),
1489 body: msg.body,
1490 timestamp: msg.sent_at.unix_timestamp() as u64,
1491 sender_id: msg.sender_id.to_proto(),
1492 nonce: Some(msg.nonce.as_u128().into()),
1493 })
1494 .collect::<Vec<_>>();
1495 response.send(proto::GetChannelMessagesResponse {
1496 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
1497 messages,
1498 })?;
1499 Ok(())
1500 }
1501
1502 async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
1503 #[cfg(test)]
1504 tokio::task::yield_now().await;
1505 let guard = self.store.read().await;
1506 #[cfg(test)]
1507 tokio::task::yield_now().await;
1508 StoreReadGuard {
1509 guard,
1510 _not_send: PhantomData,
1511 }
1512 }
1513
1514 async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
1515 #[cfg(test)]
1516 tokio::task::yield_now().await;
1517 let guard = self.store.write().await;
1518 #[cfg(test)]
1519 tokio::task::yield_now().await;
1520 StoreWriteGuard {
1521 guard,
1522 _not_send: PhantomData,
1523 }
1524 }
1525
1526 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1527 ServerSnapshot {
1528 store: self.store.read().await,
1529 peer: &self.peer,
1530 }
1531 }
1532}
1533
1534impl<'a> Deref for StoreReadGuard<'a> {
1535 type Target = Store;
1536
1537 fn deref(&self) -> &Self::Target {
1538 &*self.guard
1539 }
1540}
1541
1542impl<'a> Deref for StoreWriteGuard<'a> {
1543 type Target = Store;
1544
1545 fn deref(&self) -> &Self::Target {
1546 &*self.guard
1547 }
1548}
1549
1550impl<'a> DerefMut for StoreWriteGuard<'a> {
1551 fn deref_mut(&mut self) -> &mut Self::Target {
1552 &mut *self.guard
1553 }
1554}
1555
1556impl<'a> Drop for StoreWriteGuard<'a> {
1557 fn drop(&mut self) {
1558 #[cfg(test)]
1559 self.check_invariants();
1560
1561 let metrics = self.metrics();
1562 tracing::info!(
1563 connections = metrics.connections,
1564 registered_projects = metrics.registered_projects,
1565 shared_projects = metrics.shared_projects,
1566 "metrics"
1567 );
1568 }
1569}
1570
1571impl Executor for RealExecutor {
1572 type Sleep = Sleep;
1573
1574 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1575 tokio::task::spawn(future);
1576 }
1577
1578 fn sleep(&self, duration: Duration) -> Self::Sleep {
1579 tokio::time::sleep(duration)
1580 }
1581}
1582
1583fn broadcast<F>(
1584 sender_id: ConnectionId,
1585 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1586 mut f: F,
1587) where
1588 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1589{
1590 for receiver_id in receiver_ids {
1591 if receiver_id != sender_id {
1592 f(receiver_id).trace_err();
1593 }
1594 }
1595}
1596
1597lazy_static! {
1598 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
1599}
1600
1601pub struct ProtocolVersion(u32);
1602
1603impl Header for ProtocolVersion {
1604 fn name() -> &'static HeaderName {
1605 &ZED_PROTOCOL_VERSION
1606 }
1607
1608 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1609 where
1610 Self: Sized,
1611 I: Iterator<Item = &'i axum::http::HeaderValue>,
1612 {
1613 let version = values
1614 .next()
1615 .ok_or_else(|| axum::headers::Error::invalid())?
1616 .to_str()
1617 .map_err(|_| axum::headers::Error::invalid())?
1618 .parse()
1619 .map_err(|_| axum::headers::Error::invalid())?;
1620 Ok(Self(version))
1621 }
1622
1623 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1624 values.extend([self.0.to_string().parse().unwrap()]);
1625 }
1626}
1627
1628pub fn routes(server: Arc<Server>) -> Router<Body> {
1629 Router::new()
1630 .route("/rpc", get(handle_websocket_request))
1631 .layer(
1632 ServiceBuilder::new()
1633 .layer(Extension(server.app_state.clone()))
1634 .layer(middleware::from_fn(auth::validate_header))
1635 .layer(Extension(server)),
1636 )
1637}
1638
1639pub async fn handle_websocket_request(
1640 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1641 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1642 Extension(server): Extension<Arc<Server>>,
1643 Extension(user): Extension<User>,
1644 ws: WebSocketUpgrade,
1645) -> axum::response::Response {
1646 if protocol_version != rpc::PROTOCOL_VERSION {
1647 return (
1648 StatusCode::UPGRADE_REQUIRED,
1649 "client must be upgraded".to_string(),
1650 )
1651 .into_response();
1652 }
1653 let socket_address = socket_address.to_string();
1654 ws.on_upgrade(move |socket| {
1655 use util::ResultExt;
1656 let socket = socket
1657 .map_ok(to_tungstenite_message)
1658 .err_into()
1659 .with(|message| async move { Ok(to_axum_message(message)) });
1660 let connection = Connection::new(Box::pin(socket));
1661 async move {
1662 server
1663 .handle_connection(connection, socket_address, user, None, RealExecutor)
1664 .await
1665 .log_err();
1666 }
1667 })
1668}
1669
1670fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1671 match message {
1672 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1673 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1674 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1675 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1676 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1677 code: frame.code.into(),
1678 reason: frame.reason,
1679 })),
1680 }
1681}
1682
1683fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1684 match message {
1685 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1686 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1687 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1688 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1689 AxumMessage::Close(frame) => {
1690 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1691 code: frame.code.into(),
1692 reason: frame.reason,
1693 }))
1694 }
1695 }
1696}
1697
/// Extension for results whose errors should be logged rather than propagated.
///
/// The blanket impl for `Result` in this file logs the error with `tracing::error!`.
pub trait ResultExt {
    /// The success type of the extended result.
    type Ok;

    /// Converts `self` into an `Option`, logging the error instead of returning it.
    fn trace_err(self) -> Option<Self::Ok>;
}
1703
1704impl<T, E> ResultExt for Result<T, E>
1705where
1706 E: std::fmt::Debug,
1707{
1708 type Ok = T;
1709
1710 fn trace_err(self) -> Option<T> {
1711 match self {
1712 Ok(value) => Some(value),
1713 Err(error) => {
1714 tracing::error!("{:?}", error);
1715 None
1716 }
1717 }
1718 }
1719}