1mod store;
2
3use crate::{
4 auth,
5 db::{self, ProjectId, RoomId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::{HashMap, HashSet};
26use futures::{
27 channel::oneshot,
28 future::{self, BoxFuture},
29 stream::FuturesUnordered,
30 FutureExt, SinkExt, StreamExt, TryStreamExt,
31};
32use lazy_static::lazy_static;
33use prometheus::{register_int_gauge, IntGauge};
34use rpc::{
35 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
36 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
37};
38use serde::{Serialize, Serializer};
39use std::{
40 any::TypeId,
41 future::Future,
42 marker::PhantomData,
43 net::SocketAddr,
44 ops::{Deref, DerefMut},
45 rc::Rc,
46 sync::{
47 atomic::{AtomicBool, Ordering::SeqCst},
48 Arc,
49 },
50 time::Duration,
51};
52pub use store::{Store, Worktree};
53use tokio::{
54 sync::{Mutex, MutexGuard},
55 time::Sleep,
56};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
60lazy_static! {
61 static ref METRIC_CONNECTIONS: IntGauge =
62 register_int_gauge!("connections", "number of connections").unwrap();
63 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
64 "shared_projects",
65 "number of open projects with one or more guests"
66 )
67 .unwrap();
68}
69
70type MessageHandler = Box<
71 dyn Send + Sync + Fn(Arc<Server>, UserId, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>,
72>;
73
/// An incoming message payload together with who sent it, as passed to handlers.
struct Message<T> {
    /// Id of the user that owns the connection the message arrived on.
    sender_user_id: UserId,
    /// Connection the message arrived on.
    sender_connection_id: ConnectionId,
    /// The decoded message payload.
    payload: T,
}
79
/// Handle a request handler uses to reply to a request of type `R`.
///
/// `send` consumes the handle, so at most one reply can be sent through it.
struct Response<R> {
    /// Server whose peer delivers the reply.
    server: Arc<Server>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`. A clone of this flag is kept by the dispatcher in
    /// `add_request_handler`, which checks it to detect handlers that never replied.
    responded: Arc<AtomicBool>,
}
85
86impl<R: RequestMessage> Response<R> {
87 fn send(self, payload: R::Response) -> Result<()> {
88 self.responded.store(true, SeqCst);
89 self.server.peer.respond(self.receipt, payload)?;
90 Ok(())
91 }
92}
93
/// RPC server state: the peer used to talk to client connections, the connection
/// `Store`, shared application state and the per-message-type handler table.
pub struct Server {
    /// Peer used to send messages, requests and responses on client connections.
    peer: Arc<Peer>,
    /// Connection/user bookkeeping, guarded by a tokio (async) `Mutex`.
    pub(crate) store: Mutex<Store>,
    /// Shared application state (handlers read `db`, `config` and `live_kit_client` from it).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle; filled in by `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
}
100
/// Runtime abstraction used by `Server::handle_connection` to spawn detached work and to
/// build the timer function handed to the peer connection.
pub trait Executor: Clone + Send {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Future + Send;

    /// Spawns `future` without awaiting it (presumably runs it to completion in the
    /// background — confirm against implementors).
    fn spawn_detached<F>(&self, future: F)
    where
        F: Future<Output = ()> + Send + 'static;

    /// Returns a future for a delay of `duration` (semantics defined by implementors).
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
106
/// Stateless executor handle; `Clone` so it can be passed by value to each connection.
/// NOTE(review): presumably the production `Executor` implementation — its `impl` is
/// defined elsewhere; confirm there.
#[derive(Clone)]
pub struct RealExecutor;
109
/// Lock guard over the server's [`Store`].
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send` (because `Rc` is `!Send`).
/// A future that holds it across an `.await` is therefore `!Send` and cannot be used as
/// a message handler (those futures must be `Send`), which keeps the store lock from
/// being held across suspension points in handlers.
pub(crate) struct StoreGuard<'a> {
    /// The underlying tokio mutex guard.
    guard: MutexGuard<'a, Store>,
    /// Zero-sized marker whose only purpose is to opt out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
114
/// Serializable view of the server's peer and store.
///
/// The snapshot holds a [`StoreGuard`], so the store stays locked for as long as the
/// snapshot is alive. Field order is the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the `Store` the guard points to.
    #[serde(serialize_with = "serialize_deref")]
    store: StoreGuard<'a>,
}
121
122pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
123where
124 S: Serializer,
125 T: Deref<Target = U>,
126 U: Serialize,
127{
128 Serialize::serialize(value.deref(), serializer)
129}
130
131impl Server {
132 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
133 let mut server = Self {
134 peer: Peer::new(),
135 app_state,
136 store: Default::default(),
137 handlers: Default::default(),
138 };
139
140 server
141 .add_request_handler(Server::ping)
142 .add_request_handler(Server::create_room)
143 .add_request_handler(Server::join_room)
144 .add_message_handler(Server::leave_room)
145 .add_request_handler(Server::call)
146 .add_request_handler(Server::cancel_call)
147 .add_message_handler(Server::decline_call)
148 .add_request_handler(Server::update_participant_location)
149 .add_request_handler(Server::share_project)
150 .add_message_handler(Server::unshare_project)
151 .add_request_handler(Server::join_project)
152 .add_message_handler(Server::leave_project)
153 .add_request_handler(Server::update_project)
154 .add_request_handler(Server::update_worktree)
155 .add_message_handler(Server::start_language_server)
156 .add_message_handler(Server::update_language_server)
157 .add_request_handler(Server::update_diagnostic_summary)
158 .add_request_handler(Server::forward_project_request::<proto::GetHover>)
159 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
160 .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
161 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
162 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
163 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
164 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
165 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
166 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
167 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
168 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
169 .add_request_handler(
170 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
171 )
172 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
173 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
174 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
175 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
176 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
177 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
178 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
179 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
180 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
181 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
182 .add_message_handler(Server::create_buffer_for_peer)
183 .add_request_handler(Server::update_buffer)
184 .add_message_handler(Server::update_buffer_file)
185 .add_message_handler(Server::buffer_reloaded)
186 .add_message_handler(Server::buffer_saved)
187 .add_request_handler(Server::save_buffer)
188 .add_request_handler(Server::get_users)
189 .add_request_handler(Server::fuzzy_search_users)
190 .add_request_handler(Server::request_contact)
191 .add_request_handler(Server::remove_contact)
192 .add_request_handler(Server::respond_to_contact_request)
193 .add_request_handler(Server::follow)
194 .add_message_handler(Server::unfollow)
195 .add_request_handler(Server::update_followers)
196 .add_message_handler(Server::update_diff_base)
197 .add_request_handler(Server::get_private_user_info);
198
199 Arc::new(server)
200 }
201
202 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
203 where
204 F: 'static + Send + Sync + Fn(Arc<Self>, UserId, TypedEnvelope<M>) -> Fut,
205 Fut: 'static + Send + Future<Output = Result<()>>,
206 M: EnvelopedMessage,
207 {
208 let prev_handler = self.handlers.insert(
209 TypeId::of::<M>(),
210 Box::new(move |server, sender_user_id, envelope| {
211 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
212 let span = info_span!(
213 "handle message",
214 payload_type = envelope.payload_type_name()
215 );
216 span.in_scope(|| {
217 tracing::info!(
218 payload_type = envelope.payload_type_name(),
219 "message received"
220 );
221 });
222 let future = (handler)(server, sender_user_id, *envelope);
223 async move {
224 if let Err(error) = future.await {
225 tracing::error!(%error, "error handling message");
226 }
227 }
228 .instrument(span)
229 .boxed()
230 }),
231 );
232 if prev_handler.is_some() {
233 panic!("registered a handler for the same message twice");
234 }
235 self
236 }
237
238 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
239 where
240 F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>) -> Fut,
241 Fut: 'static + Send + Future<Output = Result<()>>,
242 M: EnvelopedMessage,
243 {
244 self.add_handler(move |server, sender_user_id, envelope| {
245 handler(
246 server,
247 Message {
248 sender_user_id,
249 sender_connection_id: envelope.sender_id,
250 payload: envelope.payload,
251 },
252 )
253 });
254 self
255 }
256
257 /// Handle a request while holding a lock to the store. This is useful when we're registering
258 /// a connection but we want to respond on the connection before anybody else can send on it.
259 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
260 where
261 F: 'static + Send + Sync + Fn(Arc<Self>, Message<M>, Response<M>) -> Fut,
262 Fut: Send + Future<Output = Result<()>>,
263 M: RequestMessage,
264 {
265 let handler = Arc::new(handler);
266 self.add_handler(move |server, sender_user_id, envelope| {
267 let receipt = envelope.receipt();
268 let handler = handler.clone();
269 async move {
270 let request = Message {
271 sender_user_id,
272 sender_connection_id: envelope.sender_id,
273 payload: envelope.payload,
274 };
275 let responded = Arc::new(AtomicBool::default());
276 let response = Response {
277 server: server.clone(),
278 responded: responded.clone(),
279 receipt,
280 };
281 match (handler)(server.clone(), request, response).await {
282 Ok(()) => {
283 if responded.load(std::sync::atomic::Ordering::SeqCst) {
284 Ok(())
285 } else {
286 Err(anyhow!("handler did not send a response"))?
287 }
288 }
289 Err(error) => {
290 server.peer.respond_with_error(
291 receipt,
292 proto::Error {
293 message: error.to_string(),
294 },
295 )?;
296 Err(error)
297 }
298 }
299 }
300 })
301 }
302
    /// Drives one client connection from registration to sign-out.
    ///
    /// Steps, in order:
    /// 1. Registers `connection` with the peer and sends `Hello` with the new connection id.
    /// 2. Reports the id through `send_connection_id`, if provided.
    /// 3. Sends initial state: `ShowContacts` on the user's first connection, the initial
    ///    contacts update, invite info (if the user has an invite code) and any pending
    ///    incoming call. It also adds the connection to the store and refreshes this user's
    ///    entry for their contacts.
    /// 4. Dispatches incoming messages to registered handlers until I/O finishes or the
    ///    incoming stream ends.
    /// 5. Signs the user out; a sign-out error is logged, not returned.
    ///
    /// `address` is only used for logging/tracing. `executor` builds the timer function
    /// passed to `Peer::add_connection` and spawns background message handlers.
    ///
    /// NOTE(review): any `?` before the message loop returns early without calling
    /// `sign_out`, so the connection registered with the peer (and, once
    /// `store.add_connection` has run, with the store) is not cleaned up on those paths.
    /// Confirm whether the peer/caller handles this.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    // Timer factory handed to the peer; backed by the executor's `sleep`.
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Receiver drop is ignored: reporting the id is best-effort.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Fetch contacts and invite code concurrently.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Register the connection and send its initial contacts/invite state while
            // holding the store lock.
            {
                let mut store = this.store().await;
                store.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in order: connection I/O first, then
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(this.clone(), user_id, message);
                                // Leave the span before the future is moved elsewhere; the
                                // future is instrumented with it instead.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    // Background handlers are not polled by this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // Incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id, user_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
424
425 #[instrument(skip(self), err)]
426 async fn sign_out(
427 self: &mut Arc<Self>,
428 connection_id: ConnectionId,
429 user_id: UserId,
430 ) -> Result<()> {
431 self.peer.disconnect(connection_id);
432 let decline_calls = {
433 let mut store = self.store().await;
434 store.remove_connection(connection_id)?;
435 let mut connections = store.connection_ids_for_user(user_id);
436 connections.next().is_none()
437 };
438
439 self.leave_room_for_connection(connection_id, user_id)
440 .await
441 .trace_err();
442 if decline_calls {
443 if let Some(room) = self
444 .app_state
445 .db
446 .decline_call(None, user_id)
447 .await
448 .trace_err()
449 {
450 self.room_updated(&room);
451 }
452 }
453
454 self.update_user_contacts(user_id).await?;
455
456 Ok(())
457 }
458
459 pub async fn invite_code_redeemed(
460 self: &Arc<Self>,
461 inviter_id: UserId,
462 invitee_id: UserId,
463 ) -> Result<()> {
464 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
465 if let Some(code) = &user.invite_code {
466 let store = self.store().await;
467 let invitee_contact = store.contact_for_user(invitee_id, true, false);
468 for connection_id in store.connection_ids_for_user(inviter_id) {
469 self.peer.send(
470 connection_id,
471 proto::UpdateContacts {
472 contacts: vec![invitee_contact.clone()],
473 ..Default::default()
474 },
475 )?;
476 self.peer.send(
477 connection_id,
478 proto::UpdateInviteInfo {
479 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
480 count: user.invite_count as u32,
481 },
482 )?;
483 }
484 }
485 }
486 Ok(())
487 }
488
489 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
490 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
491 if let Some(invite_code) = &user.invite_code {
492 let store = self.store().await;
493 for connection_id in store.connection_ids_for_user(user_id) {
494 self.peer.send(
495 connection_id,
496 proto::UpdateInviteInfo {
497 url: format!(
498 "{}{}",
499 self.app_state.config.invite_link_prefix, invite_code
500 ),
501 count: user.invite_count as u32,
502 },
503 )?;
504 }
505 }
506 }
507 Ok(())
508 }
509
510 async fn ping(
511 self: Arc<Server>,
512 _: Message<proto::Ping>,
513 response: Response<proto::Ping>,
514 ) -> Result<()> {
515 response.send(proto::Ack {})?;
516 Ok(())
517 }
518
519 async fn create_room(
520 self: Arc<Server>,
521 request: Message<proto::CreateRoom>,
522 response: Response<proto::CreateRoom>,
523 ) -> Result<()> {
524 let room = self
525 .app_state
526 .db
527 .create_room(request.sender_user_id, request.sender_connection_id)
528 .await?;
529
530 let live_kit_connection_info =
531 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
532 if let Some(_) = live_kit
533 .create_room(room.live_kit_room.clone())
534 .await
535 .trace_err()
536 {
537 if let Some(token) = live_kit
538 .room_token(
539 &room.live_kit_room,
540 &request.sender_connection_id.to_string(),
541 )
542 .trace_err()
543 {
544 Some(proto::LiveKitConnectionInfo {
545 server_url: live_kit.url().into(),
546 token,
547 })
548 } else {
549 None
550 }
551 } else {
552 None
553 }
554 } else {
555 None
556 };
557
558 response.send(proto::CreateRoomResponse {
559 room: Some(room),
560 live_kit_connection_info,
561 })?;
562 self.update_user_contacts(request.sender_user_id).await?;
563 Ok(())
564 }
565
566 async fn join_room(
567 self: Arc<Server>,
568 request: Message<proto::JoinRoom>,
569 response: Response<proto::JoinRoom>,
570 ) -> Result<()> {
571 let room = self
572 .app_state
573 .db
574 .join_room(
575 RoomId::from_proto(request.payload.id),
576 request.sender_user_id,
577 request.sender_connection_id,
578 )
579 .await?;
580 for connection_id in self
581 .store()
582 .await
583 .connection_ids_for_user(request.sender_user_id)
584 {
585 self.peer
586 .send(connection_id, proto::CallCanceled {})
587 .trace_err();
588 }
589
590 let live_kit_connection_info =
591 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
592 if let Some(token) = live_kit
593 .room_token(
594 &room.live_kit_room,
595 &request.sender_connection_id.to_string(),
596 )
597 .trace_err()
598 {
599 Some(proto::LiveKitConnectionInfo {
600 server_url: live_kit.url().into(),
601 token,
602 })
603 } else {
604 None
605 }
606 } else {
607 None
608 };
609
610 self.room_updated(&room);
611 response.send(proto::JoinRoomResponse {
612 room: Some(room),
613 live_kit_connection_info,
614 })?;
615
616 self.update_user_contacts(request.sender_user_id).await?;
617 Ok(())
618 }
619
620 async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
621 self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
622 .await
623 }
624
625 async fn leave_room_for_connection(
626 self: &Arc<Server>,
627 leaving_connection_id: ConnectionId,
628 leaving_user_id: UserId,
629 ) -> Result<()> {
630 let mut contacts_to_update = HashSet::default();
631
632 let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else {
633 return Err(anyhow!("no room to leave"))?;
634 };
635 contacts_to_update.insert(leaving_user_id);
636
637 for project in left_room.left_projects.into_values() {
638 for connection_id in project.connection_ids {
639 if project.host_user_id == leaving_user_id {
640 self.peer
641 .send(
642 connection_id,
643 proto::UnshareProject {
644 project_id: project.id.to_proto(),
645 },
646 )
647 .trace_err();
648 } else {
649 self.peer
650 .send(
651 connection_id,
652 proto::RemoveProjectCollaborator {
653 project_id: project.id.to_proto(),
654 peer_id: leaving_connection_id.0,
655 },
656 )
657 .trace_err();
658 }
659 }
660
661 self.peer
662 .send(
663 leaving_connection_id,
664 proto::UnshareProject {
665 project_id: project.id.to_proto(),
666 },
667 )
668 .trace_err();
669 }
670
671 self.room_updated(&left_room.room);
672 {
673 let store = self.store().await;
674 for canceled_user_id in left_room.canceled_calls_to_user_ids {
675 for connection_id in store.connection_ids_for_user(canceled_user_id) {
676 self.peer
677 .send(connection_id, proto::CallCanceled {})
678 .trace_err();
679 }
680 contacts_to_update.insert(canceled_user_id);
681 }
682 }
683
684 for contact_user_id in contacts_to_update {
685 self.update_user_contacts(contact_user_id).await?;
686 }
687
688 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
689 live_kit
690 .remove_participant(
691 left_room.room.live_kit_room.clone(),
692 leaving_connection_id.to_string(),
693 )
694 .await
695 .trace_err();
696
697 if left_room.room.participants.is_empty() {
698 live_kit
699 .delete_room(left_room.room.live_kit_room)
700 .await
701 .trace_err();
702 }
703 }
704
705 Ok(())
706 }
707
708 async fn call(
709 self: Arc<Server>,
710 request: Message<proto::Call>,
711 response: Response<proto::Call>,
712 ) -> Result<()> {
713 let room_id = RoomId::from_proto(request.payload.room_id);
714 let calling_user_id = request.sender_user_id;
715 let calling_connection_id = request.sender_connection_id;
716 let called_user_id = UserId::from_proto(request.payload.called_user_id);
717 let initial_project_id = request
718 .payload
719 .initial_project_id
720 .map(ProjectId::from_proto);
721 if !self
722 .app_state
723 .db
724 .has_contact(calling_user_id, called_user_id)
725 .await?
726 {
727 return Err(anyhow!("cannot call a user who isn't a contact"))?;
728 }
729
730 let (room, incoming_call) = self
731 .app_state
732 .db
733 .call(
734 room_id,
735 calling_user_id,
736 calling_connection_id,
737 called_user_id,
738 initial_project_id,
739 )
740 .await?;
741 self.room_updated(&room);
742 self.update_user_contacts(called_user_id).await?;
743
744 let mut calls = self
745 .store()
746 .await
747 .connection_ids_for_user(called_user_id)
748 .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
749 .collect::<FuturesUnordered<_>>();
750
751 while let Some(call_response) = calls.next().await {
752 match call_response.as_ref() {
753 Ok(_) => {
754 response.send(proto::Ack {})?;
755 return Ok(());
756 }
757 Err(_) => {
758 call_response.trace_err();
759 }
760 }
761 }
762
763 let room = self
764 .app_state
765 .db
766 .call_failed(room_id, called_user_id)
767 .await?;
768 self.room_updated(&room);
769 self.update_user_contacts(called_user_id).await?;
770
771 Err(anyhow!("failed to ring user"))?
772 }
773
774 async fn cancel_call(
775 self: Arc<Server>,
776 request: Message<proto::CancelCall>,
777 response: Response<proto::CancelCall>,
778 ) -> Result<()> {
779 let called_user_id = UserId::from_proto(request.payload.called_user_id);
780 let room_id = RoomId::from_proto(request.payload.room_id);
781 let room = self
782 .app_state
783 .db
784 .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
785 .await?;
786 for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
787 self.peer
788 .send(connection_id, proto::CallCanceled {})
789 .trace_err();
790 }
791 self.room_updated(&room);
792 response.send(proto::Ack {})?;
793
794 self.update_user_contacts(called_user_id).await?;
795 Ok(())
796 }
797
798 async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
799 let room_id = RoomId::from_proto(message.payload.room_id);
800 let room = self
801 .app_state
802 .db
803 .decline_call(Some(room_id), message.sender_user_id)
804 .await?;
805 for connection_id in self
806 .store()
807 .await
808 .connection_ids_for_user(message.sender_user_id)
809 {
810 self.peer
811 .send(connection_id, proto::CallCanceled {})
812 .trace_err();
813 }
814 self.room_updated(&room);
815 self.update_user_contacts(message.sender_user_id).await?;
816 Ok(())
817 }
818
819 async fn update_participant_location(
820 self: Arc<Server>,
821 request: Message<proto::UpdateParticipantLocation>,
822 response: Response<proto::UpdateParticipantLocation>,
823 ) -> Result<()> {
824 let room_id = RoomId::from_proto(request.payload.room_id);
825 let location = request
826 .payload
827 .location
828 .ok_or_else(|| anyhow!("invalid location"))?;
829 let room = self
830 .app_state
831 .db
832 .update_room_participant_location(room_id, request.sender_connection_id, location)
833 .await?;
834 self.room_updated(&room);
835 response.send(proto::Ack {})?;
836 Ok(())
837 }
838
839 fn room_updated(&self, room: &proto::Room) {
840 for participant in &room.participants {
841 self.peer
842 .send(
843 ConnectionId(participant.peer_id),
844 proto::RoomUpdated {
845 room: Some(room.clone()),
846 },
847 )
848 .trace_err();
849 }
850 }
851
852 async fn share_project(
853 self: Arc<Server>,
854 request: Message<proto::ShareProject>,
855 response: Response<proto::ShareProject>,
856 ) -> Result<()> {
857 let (project_id, room) = self
858 .app_state
859 .db
860 .share_project(
861 RoomId::from_proto(request.payload.room_id),
862 request.sender_connection_id,
863 &request.payload.worktrees,
864 )
865 .await?;
866 response.send(proto::ShareProjectResponse {
867 project_id: project_id.to_proto(),
868 })?;
869 self.room_updated(&room);
870
871 Ok(())
872 }
873
874 async fn unshare_project(
875 self: Arc<Server>,
876 message: Message<proto::UnshareProject>,
877 ) -> Result<()> {
878 let project_id = ProjectId::from_proto(message.payload.project_id);
879
880 let (room, guest_connection_ids) = self
881 .app_state
882 .db
883 .unshare_project(project_id, message.sender_connection_id)
884 .await?;
885
886 broadcast(
887 message.sender_connection_id,
888 guest_connection_ids,
889 |conn_id| self.peer.send(conn_id, message.payload.clone()),
890 );
891 self.room_updated(&room);
892
893 Ok(())
894 }
895
896 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
897 let contacts = self.app_state.db.get_contacts(user_id).await?;
898 let busy = self.app_state.db.is_user_busy(user_id).await?;
899 let store = self.store().await;
900 let updated_contact = store.contact_for_user(user_id, false, busy);
901 for contact in contacts {
902 if let db::Contact::Accepted {
903 user_id: contact_user_id,
904 ..
905 } = contact
906 {
907 for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
908 self.peer
909 .send(
910 contact_conn_id,
911 proto::UpdateContacts {
912 contacts: vec![updated_contact.clone()],
913 remove_contacts: Default::default(),
914 incoming_requests: Default::default(),
915 remove_incoming_requests: Default::default(),
916 outgoing_requests: Default::default(),
917 remove_outgoing_requests: Default::default(),
918 },
919 )
920 .trace_err();
921 }
922 }
923 }
924 Ok(())
925 }
926
927 async fn join_project(
928 self: Arc<Server>,
929 request: Message<proto::JoinProject>,
930 response: Response<proto::JoinProject>,
931 ) -> Result<()> {
932 let project_id = ProjectId::from_proto(request.payload.project_id);
933 let guest_user_id = request.sender_user_id;
934
935 tracing::info!(%project_id, "join project");
936
937 let (project, replica_id) = self
938 .app_state
939 .db
940 .join_project(project_id, request.sender_connection_id)
941 .await?;
942
943 let collaborators = project
944 .collaborators
945 .iter()
946 .filter(|collaborator| {
947 collaborator.connection_id != request.sender_connection_id.0 as i32
948 })
949 .map(|collaborator| proto::Collaborator {
950 peer_id: collaborator.connection_id as u32,
951 replica_id: collaborator.replica_id.0 as u32,
952 user_id: collaborator.user_id.to_proto(),
953 })
954 .collect::<Vec<_>>();
955 let worktrees = project
956 .worktrees
957 .iter()
958 .map(|(id, worktree)| proto::WorktreeMetadata {
959 id: id.to_proto(),
960 root_name: worktree.root_name.clone(),
961 visible: worktree.visible,
962 abs_path: worktree.abs_path.clone(),
963 })
964 .collect::<Vec<_>>();
965
966 for collaborator in &collaborators {
967 self.peer
968 .send(
969 ConnectionId(collaborator.peer_id),
970 proto::AddProjectCollaborator {
971 project_id: project_id.to_proto(),
972 collaborator: Some(proto::Collaborator {
973 peer_id: request.sender_connection_id.0,
974 replica_id: replica_id.0 as u32,
975 user_id: guest_user_id.to_proto(),
976 }),
977 },
978 )
979 .trace_err();
980 }
981
982 // First, we send the metadata associated with each worktree.
983 response.send(proto::JoinProjectResponse {
984 worktrees: worktrees.clone(),
985 replica_id: replica_id.0 as u32,
986 collaborators: collaborators.clone(),
987 language_servers: project.language_servers.clone(),
988 })?;
989
990 for (worktree_id, worktree) in project.worktrees {
991 #[cfg(any(test, feature = "test-support"))]
992 const MAX_CHUNK_SIZE: usize = 2;
993 #[cfg(not(any(test, feature = "test-support")))]
994 const MAX_CHUNK_SIZE: usize = 256;
995
996 // Stream this worktree's entries.
997 let message = proto::UpdateWorktree {
998 project_id: project_id.to_proto(),
999 worktree_id: worktree_id.to_proto(),
1000 abs_path: worktree.abs_path.clone(),
1001 root_name: worktree.root_name,
1002 updated_entries: worktree.entries,
1003 removed_entries: Default::default(),
1004 scan_id: worktree.scan_id,
1005 is_last_update: worktree.is_complete,
1006 };
1007 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1008 self.peer
1009 .send(request.sender_connection_id, update.clone())?;
1010 }
1011
1012 // Stream this worktree's diagnostics.
1013 for summary in worktree.diagnostic_summaries {
1014 self.peer.send(
1015 request.sender_connection_id,
1016 proto::UpdateDiagnosticSummary {
1017 project_id: project_id.to_proto(),
1018 worktree_id: worktree.id.to_proto(),
1019 summary: Some(summary),
1020 },
1021 )?;
1022 }
1023 }
1024
1025 for language_server in &project.language_servers {
1026 self.peer.send(
1027 request.sender_connection_id,
1028 proto::UpdateLanguageServer {
1029 project_id: project_id.to_proto(),
1030 language_server_id: language_server.id,
1031 variant: Some(
1032 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1033 proto::LspDiskBasedDiagnosticsUpdated {},
1034 ),
1035 ),
1036 },
1037 )?;
1038 }
1039
1040 Ok(())
1041 }
1042
1043 async fn leave_project(self: Arc<Server>, request: Message<proto::LeaveProject>) -> Result<()> {
1044 let sender_id = request.sender_connection_id;
1045 let project_id = ProjectId::from_proto(request.payload.project_id);
1046 let project;
1047 {
1048 project = self
1049 .app_state
1050 .db
1051 .leave_project(project_id, sender_id)
1052 .await?;
1053 tracing::info!(
1054 %project_id,
1055 host_user_id = %project.host_user_id,
1056 host_connection_id = %project.host_connection_id,
1057 "leave project"
1058 );
1059
1060 broadcast(sender_id, project.connection_ids, |conn_id| {
1061 self.peer.send(
1062 conn_id,
1063 proto::RemoveProjectCollaborator {
1064 project_id: project_id.to_proto(),
1065 peer_id: sender_id.0,
1066 },
1067 )
1068 });
1069 }
1070
1071 Ok(())
1072 }
1073
1074 async fn update_project(
1075 self: Arc<Server>,
1076 request: Message<proto::UpdateProject>,
1077 response: Response<proto::UpdateProject>,
1078 ) -> Result<()> {
1079 let project_id = ProjectId::from_proto(request.payload.project_id);
1080 let (room, guest_connection_ids) = self
1081 .app_state
1082 .db
1083 .update_project(
1084 project_id,
1085 request.sender_connection_id,
1086 &request.payload.worktrees,
1087 )
1088 .await?;
1089 broadcast(
1090 request.sender_connection_id,
1091 guest_connection_ids,
1092 |connection_id| {
1093 self.peer.forward_send(
1094 request.sender_connection_id,
1095 connection_id,
1096 request.payload.clone(),
1097 )
1098 },
1099 );
1100 self.room_updated(&room);
1101 response.send(proto::Ack {})?;
1102
1103 Ok(())
1104 }
1105
1106 async fn update_worktree(
1107 self: Arc<Server>,
1108 request: Message<proto::UpdateWorktree>,
1109 response: Response<proto::UpdateWorktree>,
1110 ) -> Result<()> {
1111 let guest_connection_ids = self
1112 .app_state
1113 .db
1114 .update_worktree(&request.payload, request.sender_connection_id)
1115 .await?;
1116
1117 broadcast(
1118 request.sender_connection_id,
1119 guest_connection_ids,
1120 |connection_id| {
1121 self.peer.forward_send(
1122 request.sender_connection_id,
1123 connection_id,
1124 request.payload.clone(),
1125 )
1126 },
1127 );
1128 response.send(proto::Ack {})?;
1129 Ok(())
1130 }
1131
1132 async fn update_diagnostic_summary(
1133 self: Arc<Server>,
1134 request: Message<proto::UpdateDiagnosticSummary>,
1135 response: Response<proto::UpdateDiagnosticSummary>,
1136 ) -> Result<()> {
1137 let guest_connection_ids = self
1138 .app_state
1139 .db
1140 .update_diagnostic_summary(&request.payload, request.sender_connection_id)
1141 .await?;
1142
1143 broadcast(
1144 request.sender_connection_id,
1145 guest_connection_ids,
1146 |connection_id| {
1147 self.peer.forward_send(
1148 request.sender_connection_id,
1149 connection_id,
1150 request.payload.clone(),
1151 )
1152 },
1153 );
1154
1155 response.send(proto::Ack {})?;
1156 Ok(())
1157 }
1158
1159 async fn start_language_server(
1160 self: Arc<Server>,
1161 request: Message<proto::StartLanguageServer>,
1162 ) -> Result<()> {
1163 let guest_connection_ids = self
1164 .app_state
1165 .db
1166 .start_language_server(&request.payload, request.sender_connection_id)
1167 .await?;
1168
1169 broadcast(
1170 request.sender_connection_id,
1171 guest_connection_ids,
1172 |connection_id| {
1173 self.peer.forward_send(
1174 request.sender_connection_id,
1175 connection_id,
1176 request.payload.clone(),
1177 )
1178 },
1179 );
1180 Ok(())
1181 }
1182
1183 async fn update_language_server(
1184 self: Arc<Server>,
1185 request: Message<proto::UpdateLanguageServer>,
1186 ) -> Result<()> {
1187 let project_id = ProjectId::from_proto(request.payload.project_id);
1188 let project_connection_ids = self
1189 .app_state
1190 .db
1191 .project_connection_ids(project_id, request.sender_connection_id)
1192 .await?;
1193 broadcast(
1194 request.sender_connection_id,
1195 project_connection_ids,
1196 |connection_id| {
1197 self.peer.forward_send(
1198 request.sender_connection_id,
1199 connection_id,
1200 request.payload.clone(),
1201 )
1202 },
1203 );
1204 Ok(())
1205 }
1206
1207 async fn forward_project_request<T>(
1208 self: Arc<Server>,
1209 request: Message<T>,
1210 response: Response<T>,
1211 ) -> Result<()>
1212 where
1213 T: EntityMessage + RequestMessage,
1214 {
1215 let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
1216 let collaborators = self
1217 .app_state
1218 .db
1219 .project_collaborators(project_id, request.sender_connection_id)
1220 .await?;
1221 let host = collaborators
1222 .iter()
1223 .find(|collaborator| collaborator.is_host)
1224 .ok_or_else(|| anyhow!("host not found"))?;
1225
1226 let payload = self
1227 .peer
1228 .forward_request(
1229 request.sender_connection_id,
1230 ConnectionId(host.connection_id as u32),
1231 request.payload,
1232 )
1233 .await?;
1234
1235 response.send(payload)?;
1236 Ok(())
1237 }
1238
1239 async fn save_buffer(
1240 self: Arc<Server>,
1241 request: Message<proto::SaveBuffer>,
1242 response: Response<proto::SaveBuffer>,
1243 ) -> Result<()> {
1244 let project_id = ProjectId::from_proto(request.payload.project_id);
1245 let collaborators = self
1246 .app_state
1247 .db
1248 .project_collaborators(project_id, request.sender_connection_id)
1249 .await?;
1250 let host = collaborators
1251 .into_iter()
1252 .find(|collaborator| collaborator.is_host)
1253 .ok_or_else(|| anyhow!("host not found"))?;
1254 let host_connection_id = ConnectionId(host.connection_id as u32);
1255 let response_payload = self
1256 .peer
1257 .forward_request(
1258 request.sender_connection_id,
1259 host_connection_id,
1260 request.payload.clone(),
1261 )
1262 .await?;
1263
1264 let mut collaborators = self
1265 .app_state
1266 .db
1267 .project_collaborators(project_id, request.sender_connection_id)
1268 .await?;
1269 collaborators.retain(|collaborator| {
1270 collaborator.connection_id != request.sender_connection_id.0 as i32
1271 });
1272 let project_connection_ids = collaborators
1273 .into_iter()
1274 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1275 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1276 self.peer
1277 .forward_send(host_connection_id, conn_id, response_payload.clone())
1278 });
1279 response.send(response_payload)?;
1280 Ok(())
1281 }
1282
1283 async fn create_buffer_for_peer(
1284 self: Arc<Server>,
1285 request: Message<proto::CreateBufferForPeer>,
1286 ) -> Result<()> {
1287 self.peer.forward_send(
1288 request.sender_connection_id,
1289 ConnectionId(request.payload.peer_id),
1290 request.payload,
1291 )?;
1292 Ok(())
1293 }
1294
1295 async fn update_buffer(
1296 self: Arc<Server>,
1297 request: Message<proto::UpdateBuffer>,
1298 response: Response<proto::UpdateBuffer>,
1299 ) -> Result<()> {
1300 let project_id = ProjectId::from_proto(request.payload.project_id);
1301 let project_connection_ids = self
1302 .app_state
1303 .db
1304 .project_connection_ids(project_id, request.sender_connection_id)
1305 .await?;
1306
1307 broadcast(
1308 request.sender_connection_id,
1309 project_connection_ids,
1310 |connection_id| {
1311 self.peer.forward_send(
1312 request.sender_connection_id,
1313 connection_id,
1314 request.payload.clone(),
1315 )
1316 },
1317 );
1318 response.send(proto::Ack {})?;
1319 Ok(())
1320 }
1321
1322 async fn update_buffer_file(
1323 self: Arc<Server>,
1324 request: Message<proto::UpdateBufferFile>,
1325 ) -> Result<()> {
1326 let project_id = ProjectId::from_proto(request.payload.project_id);
1327 let project_connection_ids = self
1328 .app_state
1329 .db
1330 .project_connection_ids(project_id, request.sender_connection_id)
1331 .await?;
1332
1333 broadcast(
1334 request.sender_connection_id,
1335 project_connection_ids,
1336 |connection_id| {
1337 self.peer.forward_send(
1338 request.sender_connection_id,
1339 connection_id,
1340 request.payload.clone(),
1341 )
1342 },
1343 );
1344 Ok(())
1345 }
1346
1347 async fn buffer_reloaded(
1348 self: Arc<Server>,
1349 request: Message<proto::BufferReloaded>,
1350 ) -> Result<()> {
1351 let project_id = ProjectId::from_proto(request.payload.project_id);
1352 let project_connection_ids = self
1353 .app_state
1354 .db
1355 .project_connection_ids(project_id, request.sender_connection_id)
1356 .await?;
1357 broadcast(
1358 request.sender_connection_id,
1359 project_connection_ids,
1360 |connection_id| {
1361 self.peer.forward_send(
1362 request.sender_connection_id,
1363 connection_id,
1364 request.payload.clone(),
1365 )
1366 },
1367 );
1368 Ok(())
1369 }
1370
1371 async fn buffer_saved(self: Arc<Server>, request: Message<proto::BufferSaved>) -> Result<()> {
1372 let project_id = ProjectId::from_proto(request.payload.project_id);
1373 let project_connection_ids = self
1374 .app_state
1375 .db
1376 .project_connection_ids(project_id, request.sender_connection_id)
1377 .await?;
1378 broadcast(
1379 request.sender_connection_id,
1380 project_connection_ids,
1381 |connection_id| {
1382 self.peer.forward_send(
1383 request.sender_connection_id,
1384 connection_id,
1385 request.payload.clone(),
1386 )
1387 },
1388 );
1389 Ok(())
1390 }
1391
1392 async fn follow(
1393 self: Arc<Self>,
1394 request: Message<proto::Follow>,
1395 response: Response<proto::Follow>,
1396 ) -> Result<()> {
1397 let project_id = ProjectId::from_proto(request.payload.project_id);
1398 let leader_id = ConnectionId(request.payload.leader_id);
1399 let follower_id = request.sender_connection_id;
1400 let project_connection_ids = self
1401 .app_state
1402 .db
1403 .project_connection_ids(project_id, request.sender_connection_id)
1404 .await?;
1405
1406 if !project_connection_ids.contains(&leader_id) {
1407 Err(anyhow!("no such peer"))?;
1408 }
1409
1410 let mut response_payload = self
1411 .peer
1412 .forward_request(request.sender_connection_id, leader_id, request.payload)
1413 .await?;
1414 response_payload
1415 .views
1416 .retain(|view| view.leader_id != Some(follower_id.0));
1417 response.send(response_payload)?;
1418 Ok(())
1419 }
1420
1421 async fn unfollow(self: Arc<Self>, request: Message<proto::Unfollow>) -> Result<()> {
1422 let project_id = ProjectId::from_proto(request.payload.project_id);
1423 let leader_id = ConnectionId(request.payload.leader_id);
1424 let project_connection_ids = self
1425 .app_state
1426 .db
1427 .project_connection_ids(project_id, request.sender_connection_id)
1428 .await?;
1429 if !project_connection_ids.contains(&leader_id) {
1430 Err(anyhow!("no such peer"))?;
1431 }
1432 self.peer
1433 .forward_send(request.sender_connection_id, leader_id, request.payload)?;
1434 Ok(())
1435 }
1436
1437 async fn update_followers(
1438 self: Arc<Self>,
1439 request: Message<proto::UpdateFollowers>,
1440 response: Response<proto::UpdateFollowers>,
1441 ) -> Result<()> {
1442 let project_id = ProjectId::from_proto(request.payload.project_id);
1443 let project_connection_ids = self
1444 .app_state
1445 .db
1446 .project_connection_ids(project_id, request.sender_connection_id)
1447 .await?;
1448
1449 let leader_id = request
1450 .payload
1451 .variant
1452 .as_ref()
1453 .and_then(|variant| match variant {
1454 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1455 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1456 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1457 });
1458 for follower_id in &request.payload.follower_ids {
1459 let follower_id = ConnectionId(*follower_id);
1460 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1461 self.peer.forward_send(
1462 request.sender_connection_id,
1463 follower_id,
1464 request.payload.clone(),
1465 )?;
1466 }
1467 }
1468 response.send(proto::Ack {})?;
1469 Ok(())
1470 }
1471
1472 async fn get_users(
1473 self: Arc<Server>,
1474 request: Message<proto::GetUsers>,
1475 response: Response<proto::GetUsers>,
1476 ) -> Result<()> {
1477 let user_ids = request
1478 .payload
1479 .user_ids
1480 .into_iter()
1481 .map(UserId::from_proto)
1482 .collect();
1483 let users = self
1484 .app_state
1485 .db
1486 .get_users_by_ids(user_ids)
1487 .await?
1488 .into_iter()
1489 .map(|user| proto::User {
1490 id: user.id.to_proto(),
1491 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1492 github_login: user.github_login,
1493 })
1494 .collect();
1495 response.send(proto::UsersResponse { users })?;
1496 Ok(())
1497 }
1498
1499 async fn fuzzy_search_users(
1500 self: Arc<Server>,
1501 request: Message<proto::FuzzySearchUsers>,
1502 response: Response<proto::FuzzySearchUsers>,
1503 ) -> Result<()> {
1504 let query = request.payload.query;
1505 let db = &self.app_state.db;
1506 let users = match query.len() {
1507 0 => vec![],
1508 1 | 2 => db
1509 .get_user_by_github_account(&query, None)
1510 .await?
1511 .into_iter()
1512 .collect(),
1513 _ => db.fuzzy_search_users(&query, 10).await?,
1514 };
1515 let users = users
1516 .into_iter()
1517 .filter(|user| user.id != request.sender_user_id)
1518 .map(|user| proto::User {
1519 id: user.id.to_proto(),
1520 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1521 github_login: user.github_login,
1522 })
1523 .collect();
1524 response.send(proto::UsersResponse { users })?;
1525 Ok(())
1526 }
1527
1528 async fn request_contact(
1529 self: Arc<Server>,
1530 request: Message<proto::RequestContact>,
1531 response: Response<proto::RequestContact>,
1532 ) -> Result<()> {
1533 let requester_id = request.sender_user_id;
1534 let responder_id = UserId::from_proto(request.payload.responder_id);
1535 if requester_id == responder_id {
1536 return Err(anyhow!("cannot add yourself as a contact"))?;
1537 }
1538
1539 self.app_state
1540 .db
1541 .send_contact_request(requester_id, responder_id)
1542 .await?;
1543
1544 // Update outgoing contact requests of requester
1545 let mut update = proto::UpdateContacts::default();
1546 update.outgoing_requests.push(responder_id.to_proto());
1547 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1548 self.peer.send(connection_id, update.clone())?;
1549 }
1550
1551 // Update incoming contact requests of responder
1552 let mut update = proto::UpdateContacts::default();
1553 update
1554 .incoming_requests
1555 .push(proto::IncomingContactRequest {
1556 requester_id: requester_id.to_proto(),
1557 should_notify: true,
1558 });
1559 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1560 self.peer.send(connection_id, update.clone())?;
1561 }
1562
1563 response.send(proto::Ack {})?;
1564 Ok(())
1565 }
1566
1567 async fn respond_to_contact_request(
1568 self: Arc<Server>,
1569 request: Message<proto::RespondToContactRequest>,
1570 response: Response<proto::RespondToContactRequest>,
1571 ) -> Result<()> {
1572 let responder_id = request.sender_user_id;
1573 let requester_id = UserId::from_proto(request.payload.requester_id);
1574 if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
1575 self.app_state
1576 .db
1577 .dismiss_contact_notification(responder_id, requester_id)
1578 .await?;
1579 } else {
1580 let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;
1581 self.app_state
1582 .db
1583 .respond_to_contact_request(responder_id, requester_id, accept)
1584 .await?;
1585 let busy = self.app_state.db.is_user_busy(requester_id).await?;
1586
1587 let store = self.store().await;
1588 // Update responder with new contact
1589 let mut update = proto::UpdateContacts::default();
1590 if accept {
1591 update
1592 .contacts
1593 .push(store.contact_for_user(requester_id, false, busy));
1594 }
1595 update
1596 .remove_incoming_requests
1597 .push(requester_id.to_proto());
1598 for connection_id in store.connection_ids_for_user(responder_id) {
1599 self.peer.send(connection_id, update.clone())?;
1600 }
1601
1602 // Update requester with new contact
1603 let mut update = proto::UpdateContacts::default();
1604 if accept {
1605 update
1606 .contacts
1607 .push(store.contact_for_user(responder_id, true, busy));
1608 }
1609 update
1610 .remove_outgoing_requests
1611 .push(responder_id.to_proto());
1612 for connection_id in store.connection_ids_for_user(requester_id) {
1613 self.peer.send(connection_id, update.clone())?;
1614 }
1615 }
1616
1617 response.send(proto::Ack {})?;
1618 Ok(())
1619 }
1620
1621 async fn remove_contact(
1622 self: Arc<Server>,
1623 request: Message<proto::RemoveContact>,
1624 response: Response<proto::RemoveContact>,
1625 ) -> Result<()> {
1626 let requester_id = request.sender_user_id;
1627 let responder_id = UserId::from_proto(request.payload.user_id);
1628 self.app_state
1629 .db
1630 .remove_contact(requester_id, responder_id)
1631 .await?;
1632
1633 // Update outgoing contact requests of requester
1634 let mut update = proto::UpdateContacts::default();
1635 update
1636 .remove_outgoing_requests
1637 .push(responder_id.to_proto());
1638 for connection_id in self.store().await.connection_ids_for_user(requester_id) {
1639 self.peer.send(connection_id, update.clone())?;
1640 }
1641
1642 // Update incoming contact requests of responder
1643 let mut update = proto::UpdateContacts::default();
1644 update
1645 .remove_incoming_requests
1646 .push(requester_id.to_proto());
1647 for connection_id in self.store().await.connection_ids_for_user(responder_id) {
1648 self.peer.send(connection_id, update.clone())?;
1649 }
1650
1651 response.send(proto::Ack {})?;
1652 Ok(())
1653 }
1654
1655 async fn update_diff_base(
1656 self: Arc<Server>,
1657 request: Message<proto::UpdateDiffBase>,
1658 ) -> Result<()> {
1659 let project_id = ProjectId::from_proto(request.payload.project_id);
1660 let project_connection_ids = self
1661 .app_state
1662 .db
1663 .project_connection_ids(project_id, request.sender_connection_id)
1664 .await?;
1665 broadcast(
1666 request.sender_connection_id,
1667 project_connection_ids,
1668 |connection_id| {
1669 self.peer.forward_send(
1670 request.sender_connection_id,
1671 connection_id,
1672 request.payload.clone(),
1673 )
1674 },
1675 );
1676 Ok(())
1677 }
1678
1679 async fn get_private_user_info(
1680 self: Arc<Self>,
1681 request: Message<proto::GetPrivateUserInfo>,
1682 response: Response<proto::GetPrivateUserInfo>,
1683 ) -> Result<()> {
1684 let metrics_id = self
1685 .app_state
1686 .db
1687 .get_user_metrics_id(request.sender_user_id)
1688 .await?;
1689 let user = self
1690 .app_state
1691 .db
1692 .get_user_by_id(request.sender_user_id)
1693 .await?
1694 .ok_or_else(|| anyhow!("user not found"))?;
1695 response.send(proto::GetPrivateUserInfoResponse {
1696 metrics_id,
1697 staff: user.admin,
1698 })?;
1699 Ok(())
1700 }
1701
 /// Locks the server's `store` mutex and wraps the guard in a `StoreGuard`.
 ///
 /// In test builds this yields to the Tokio scheduler both before and after
 /// acquiring the lock. Presumably this encourages varied task interleavings
 /// in tests (TODO confirm intent).
 pub(crate) async fn store(&self) -> StoreGuard<'_> {
 #[cfg(test)]
 tokio::task::yield_now().await;
 let guard = self.store.lock().await;
 #[cfg(test)]
 tokio::task::yield_now().await;
 // NOTE(review): the raw mutex guard, not the `StoreGuard`, is what lives
 // across the test-only yield above. `_not_send` looks like a marker that
 // keeps `StoreGuard` from being `Send`. Confirm against the struct
 // definition before reordering these statements.
 StoreGuard {
 guard,
 _not_send: PhantomData,
 }
 }
1713
1714 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1715 ServerSnapshot {
1716 store: self.store().await,
1717 peer: &self.peer,
1718 }
1719 }
1720}
1721
1722impl<'a> Deref for StoreGuard<'a> {
1723 type Target = Store;
1724
1725 fn deref(&self) -> &Self::Target {
1726 &*self.guard
1727 }
1728}
1729
1730impl<'a> DerefMut for StoreGuard<'a> {
1731 fn deref_mut(&mut self) -> &mut Self::Target {
1732 &mut *self.guard
1733 }
1734}
1735
1736impl<'a> Drop for StoreGuard<'a> {
1737 fn drop(&mut self) {
1738 #[cfg(test)]
1739 self.check_invariants();
1740 }
1741}
1742
1743impl Executor for RealExecutor {
1744 type Sleep = Sleep;
1745
1746 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1747 tokio::task::spawn(future);
1748 }
1749
1750 fn sleep(&self, duration: Duration) -> Self::Sleep {
1751 tokio::time::sleep(duration)
1752 }
1753}
1754
1755fn broadcast<F>(
1756 sender_id: ConnectionId,
1757 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1758 mut f: F,
1759) where
1760 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1761{
1762 for receiver_id in receiver_ids {
1763 if receiver_id != sender_id {
1764 f(receiver_id).trace_err();
1765 }
1766 }
1767}
1768
// Name of the request header clients use to report their RPC protocol
// version (decoded by `ProtocolVersion`).
lazy_static! {
 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1772
/// Typed `x-zed-protocol-version` header: the client's RPC protocol version
/// as an unsigned decimal integer.
pub struct ProtocolVersion(u32);
1774
1775impl Header for ProtocolVersion {
1776 fn name() -> &'static HeaderName {
1777 &ZED_PROTOCOL_VERSION
1778 }
1779
1780 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1781 where
1782 Self: Sized,
1783 I: Iterator<Item = &'i axum::http::HeaderValue>,
1784 {
1785 let version = values
1786 .next()
1787 .ok_or_else(axum::headers::Error::invalid)?
1788 .to_str()
1789 .map_err(|_| axum::headers::Error::invalid())?
1790 .parse()
1791 .map_err(|_| axum::headers::Error::invalid())?;
1792 Ok(Self(version))
1793 }
1794
1795 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1796 values.extend([self.0.to_string().parse().unwrap()]);
1797 }
1798}
1799
1800pub fn routes(server: Arc<Server>) -> Router<Body> {
1801 Router::new()
1802 .route("/rpc", get(handle_websocket_request))
1803 .layer(
1804 ServiceBuilder::new()
1805 .layer(Extension(server.app_state.clone()))
1806 .layer(middleware::from_fn(auth::validate_header)),
1807 )
1808 .route("/metrics", get(handle_metrics))
1809 .layer(Extension(server))
1810}
1811
1812pub async fn handle_websocket_request(
1813 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1814 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1815 Extension(server): Extension<Arc<Server>>,
1816 Extension(user): Extension<User>,
1817 ws: WebSocketUpgrade,
1818) -> axum::response::Response {
1819 if protocol_version != rpc::PROTOCOL_VERSION {
1820 return (
1821 StatusCode::UPGRADE_REQUIRED,
1822 "client must be upgraded".to_string(),
1823 )
1824 .into_response();
1825 }
1826 let socket_address = socket_address.to_string();
1827 ws.on_upgrade(move |socket| {
1828 use util::ResultExt;
1829 let socket = socket
1830 .map_ok(to_tungstenite_message)
1831 .err_into()
1832 .with(|message| async move { Ok(to_axum_message(message)) });
1833 let connection = Connection::new(Box::pin(socket));
1834 async move {
1835 server
1836 .handle_connection(connection, socket_address, user, None, RealExecutor)
1837 .await
1838 .log_err();
1839 }
1840 })
1841}
1842
1843pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> axum::response::Response {
1844 let metrics = server.store().await.metrics();
1845 METRIC_CONNECTIONS.set(metrics.connections as _);
1846 METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _);
1847
1848 let encoder = prometheus::TextEncoder::new();
1849 let metric_families = prometheus::gather();
1850 match encoder.encode_to_string(&metric_families) {
1851 Ok(string) => (StatusCode::OK, string).into_response(),
1852 Err(error) => (
1853 StatusCode::INTERNAL_SERVER_ERROR,
1854 format!("failed to encode metrics {:?}", error),
1855 )
1856 .into_response(),
1857 }
1858}
1859
1860fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1861 match message {
1862 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1863 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1864 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1865 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1866 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1867 code: frame.code.into(),
1868 reason: frame.reason,
1869 })),
1870 }
1871}
1872
1873fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1874 match message {
1875 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1876 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1877 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1878 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1879 AxumMessage::Close(frame) => {
1880 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1881 code: frame.code.into(),
1882 reason: frame.reason,
1883 }))
1884 }
1885 }
1886}
1887
/// Extension for turning a `Result` into an `Option` while reporting the error.
///
/// The implementation for `Result<T, E>` in this module logs the error's
/// `Debug` form with `tracing::error!`.
pub trait ResultExt {
 /// The success type of the underlying result.
 type Ok;

 /// Returns `Some(value)` on success; on failure, reports the error and
 /// returns `None`.
 fn trace_err(self) -> Option<Self::Ok>;
}
1893
1894impl<T, E> ResultExt for Result<T, E>
1895where
1896 E: std::fmt::Debug,
1897{
1898 type Ok = T;
1899
1900 fn trace_err(self) -> Option<T> {
1901 match self {
1902 Ok(value) => Some(value),
1903 Err(error) => {
1904 tracing::error!("{:?}", error);
1905 None
1906 }
1907 }
1908 }
1909}