1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, DefaultDb, ProjectId, RoomId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::{HashMap, HashSet};
26pub use connection_pool::ConnectionPool;
27use futures::{
28 channel::oneshot,
29 future::{self, BoxFuture},
30 stream::FuturesUnordered,
31 FutureExt, SinkExt, StreamExt, TryStreamExt,
32};
33use lazy_static::lazy_static;
34use prometheus::{register_int_gauge, IntGauge};
35use rpc::{
36 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
37 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
38};
39use serde::{Serialize, Serializer};
40use std::{
41 any::TypeId,
42 future::Future,
43 marker::PhantomData,
44 net::SocketAddr,
45 ops::{Deref, DerefMut},
46 rc::Rc,
47 sync::{
48 atomic::{AtomicBool, Ordering::SeqCst},
49 Arc,
50 },
51 time::Duration,
52};
53use tokio::{
54 sync::{Mutex, MutexGuard},
55 time::Sleep,
56};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
lazy_static! {
    // Gauge named "connections", registered with the default Prometheus registry on first
    // access. `unwrap` panics at that point if registration fails (e.g. duplicate name).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge named "shared_projects"; per its help text it counts open projects that have at
    // least one guest. Registered (and possibly panicking) on first access like the above.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
69
/// Type-erased handler for one kind of incoming message.
///
/// It receives the server, the still-boxed envelope and the sender's `Session`, and returns a
/// boxed future doing the work. Handlers live in `Server::handlers`, keyed by the `TypeId` of
/// the payload type they were registered for (see `Server::add_handler`).
type MessageHandler = Box<
    dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>,
>;
73
/// Handle through which a request handler sends its reply to request `R`.
struct Response<R> {
    server: Arc<Server>,
    receipt: Receipt<R>,
    // Shared with the dispatcher built in `add_request_handler`, which checks it after the
    // handler returns `Ok` to detect handlers that never replied.
    responded: Arc<AtomicBool>,
}
79
/// Per-connection context handed to every message handler.
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Created once per connection in `handle_connection` and cloned into each session.
    db: Arc<Mutex<DbHandle>>,
}
85
86struct DbHandle(Arc<DefaultDb>);
87
88impl Deref for DbHandle {
89 type Target = DefaultDb;
90
91 fn deref(&self) -> &Self::Target {
92 self.0.as_ref()
93 }
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.server.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
/// The collaboration RPC server: owns the peer, the connection pool and the handler table
/// used to dispatch incoming messages.
pub struct Server {
    // RPC peer used to send to and receive from every connection.
    peer: Arc<Peer>,
    // Maps users to their live connections; guarded by an async (tokio) mutex.
    pub(crate) connection_pool: Mutex<ConnectionPool>,
    app_state: Arc<AppState>,
    // Message handlers keyed by the `TypeId` of the payload type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
110
/// Runtime abstraction used by `Server::handle_connection` for spawning and timers.
/// NOTE(review): presumably exists so tests can plug in a deterministic executor — confirm.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Spawns `future` without keeping a handle to it; `handle_connection` uses this for
    /// background messages.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future that is awaited to wait out `duration`; `handle_connection` wraps it
    /// as the timer function given to `Peer::add_connection`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
116
/// Unit-struct `Executor` intended for production use.
/// NOTE(review): its `Executor` impl is not in this part of the file; presumably tokio-backed.
#[derive(Clone)]
pub struct RealExecutor;
119
/// Lock guard over the server's [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. Because message-handler futures
    // must be `Send`, this presumably stops a handler from holding the pool lock across an
    // `.await`.
    _not_send: PhantomData<Rc<()>>,
}
124
/// Serializable view of the server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
131
132pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
133where
134 S: Serializer,
135 T: Deref<Target = U>,
136 U: Serialize,
137{
138 Serialize::serialize(value.deref(), serializer)
139}
140
141impl Server {
142 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
143 let mut server = Self {
144 peer: Peer::new(),
145 app_state,
146 connection_pool: Default::default(),
147 handlers: Default::default(),
148 };
149
150 server
151 .add_request_handler(Server::ping)
152 .add_request_handler(Server::create_room)
153 .add_request_handler(Server::join_room)
154 .add_message_handler(Server::leave_room)
155 .add_request_handler(Server::call)
156 .add_request_handler(Server::cancel_call)
157 .add_message_handler(Server::decline_call)
158 .add_request_handler(Server::update_participant_location)
159 .add_request_handler(Server::share_project)
160 .add_message_handler(Server::unshare_project)
161 .add_request_handler(Server::join_project)
162 .add_message_handler(Server::leave_project)
163 .add_request_handler(Server::update_project)
164 .add_request_handler(Server::update_worktree)
165 .add_message_handler(Server::start_language_server)
166 .add_message_handler(Server::update_language_server)
167 .add_request_handler(Server::update_diagnostic_summary)
168 .add_request_handler(Server::forward_project_request::<proto::GetHover>)
169 .add_request_handler(Server::forward_project_request::<proto::GetDefinition>)
170 .add_request_handler(Server::forward_project_request::<proto::GetTypeDefinition>)
171 .add_request_handler(Server::forward_project_request::<proto::GetReferences>)
172 .add_request_handler(Server::forward_project_request::<proto::SearchProject>)
173 .add_request_handler(Server::forward_project_request::<proto::GetDocumentHighlights>)
174 .add_request_handler(Server::forward_project_request::<proto::GetProjectSymbols>)
175 .add_request_handler(Server::forward_project_request::<proto::OpenBufferForSymbol>)
176 .add_request_handler(Server::forward_project_request::<proto::OpenBufferById>)
177 .add_request_handler(Server::forward_project_request::<proto::OpenBufferByPath>)
178 .add_request_handler(Server::forward_project_request::<proto::GetCompletions>)
179 .add_request_handler(
180 Server::forward_project_request::<proto::ApplyCompletionAdditionalEdits>,
181 )
182 .add_request_handler(Server::forward_project_request::<proto::GetCodeActions>)
183 .add_request_handler(Server::forward_project_request::<proto::ApplyCodeAction>)
184 .add_request_handler(Server::forward_project_request::<proto::PrepareRename>)
185 .add_request_handler(Server::forward_project_request::<proto::PerformRename>)
186 .add_request_handler(Server::forward_project_request::<proto::ReloadBuffers>)
187 .add_request_handler(Server::forward_project_request::<proto::FormatBuffers>)
188 .add_request_handler(Server::forward_project_request::<proto::CreateProjectEntry>)
189 .add_request_handler(Server::forward_project_request::<proto::RenameProjectEntry>)
190 .add_request_handler(Server::forward_project_request::<proto::CopyProjectEntry>)
191 .add_request_handler(Server::forward_project_request::<proto::DeleteProjectEntry>)
192 .add_message_handler(Server::create_buffer_for_peer)
193 .add_request_handler(Server::update_buffer)
194 .add_message_handler(Server::update_buffer_file)
195 .add_message_handler(Server::buffer_reloaded)
196 .add_message_handler(Server::buffer_saved)
197 .add_request_handler(Server::save_buffer)
198 .add_request_handler(Server::get_users)
199 .add_request_handler(Server::fuzzy_search_users)
200 .add_request_handler(Server::request_contact)
201 .add_request_handler(Server::remove_contact)
202 .add_request_handler(Server::respond_to_contact_request)
203 .add_request_handler(Server::follow)
204 .add_message_handler(Server::unfollow)
205 .add_message_handler(Server::update_followers)
206 .add_message_handler(Server::update_diff_base)
207 .add_request_handler(Server::get_private_user_info);
208
209 Arc::new(server)
210 }
211
212 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
213 where
214 F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Session) -> Fut,
215 Fut: 'static + Send + Future<Output = Result<()>>,
216 M: EnvelopedMessage,
217 {
218 let prev_handler = self.handlers.insert(
219 TypeId::of::<M>(),
220 Box::new(move |server, envelope, session| {
221 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
222 let span = info_span!(
223 "handle message",
224 payload_type = envelope.payload_type_name()
225 );
226 span.in_scope(|| {
227 tracing::info!(
228 payload_type = envelope.payload_type_name(),
229 "message received"
230 );
231 });
232 let future = (handler)(server, *envelope, session);
233 async move {
234 if let Err(error) = future.await {
235 tracing::error!(%error, "error handling message");
236 }
237 }
238 .instrument(span)
239 .boxed()
240 }),
241 );
242 if prev_handler.is_some() {
243 panic!("registered a handler for the same message twice");
244 }
245 self
246 }
247
248 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
249 where
250 F: 'static + Send + Sync + Fn(Arc<Self>, M, Session) -> Fut,
251 Fut: 'static + Send + Future<Output = Result<()>>,
252 M: EnvelopedMessage,
253 {
254 self.add_handler(move |server, envelope, session| {
255 handler(server, envelope.payload, session)
256 });
257 self
258 }
259
260 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
261 where
262 F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
263 Fut: Send + Future<Output = Result<()>>,
264 M: RequestMessage,
265 {
266 let handler = Arc::new(handler);
267 self.add_handler(move |server, envelope, session| {
268 let receipt = envelope.receipt();
269 let handler = handler.clone();
270 async move {
271 let responded = Arc::new(AtomicBool::default());
272 let response = Response {
273 server: server.clone(),
274 responded: responded.clone(),
275 receipt,
276 };
277 match (handler)(server.clone(), envelope.payload, response, session).await {
278 Ok(()) => {
279 if responded.load(std::sync::atomic::Ordering::SeqCst) {
280 Ok(())
281 } else {
282 Err(anyhow!("handler did not send a response"))?
283 }
284 }
285 Err(error) => {
286 server.peer.respond_with_error(
287 receipt,
288 proto::Error {
289 message: error.to_string(),
290 },
291 )?;
292 Err(error)
293 }
294 }
295 }
296 })
297 }
298
    /// Drives the whole lifecycle of one client connection.
    ///
    /// The steps, in order:
    /// - Registers `connection` with the peer and sends `Hello`.
    /// - Reports the new `ConnectionId` through `send_connection_id`, if one was given.
    /// - On the user's first-ever connection, sends `ShowContacts` and records that the user
    ///   has connected.
    /// - Adds the connection to the pool and sends the initial contacts, invite info and any
    ///   pending incoming call.
    /// - Notifies the user's contacts.
    /// - Dispatches incoming messages to registered handlers until the I/O future finishes
    ///   or the incoming stream ends.
    /// - Signs the connection out. Sign-out errors are logged, not returned.
    ///
    /// Errors during the setup steps are returned before the message loop starts.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        // `mut` because `sign_out` takes `&mut Arc<Self>`.
        let mut this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The peer gets a timer function built from `executor.sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // A dropped receiver is not an error here, so the send result is ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Pool lock is scoped to this block; nothing inside awaits.
            {
                let mut pool = this.connection_pool().await;
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            this.update_user_contacts(user_id).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // One handle per connection, cloned into every `Session` below.
            let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone())));

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective requests completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: connection I/O first,
                // then in-flight foreground handlers, then newly received messages. Once the
                // `FuturesUnordered` has reported empty it is terminated (a `FusedStream`), so
                // its branch is skipped until a new handler is pushed.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let session = Session {
                                    user_id,
                                    connection_id,
                                    db: db.clone(),
                                };
                                let handle_message = (handler)(this.clone(), message, session);
                                // Exit the span here; `instrument` re-enters it on each poll.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = this.sign_out(connection_id, user_id).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
427
428 #[instrument(skip(self), err)]
429 async fn sign_out(
430 self: &mut Arc<Self>,
431 connection_id: ConnectionId,
432 user_id: UserId,
433 ) -> Result<()> {
434 self.peer.disconnect(connection_id);
435 let decline_calls = {
436 let mut pool = self.connection_pool().await;
437 pool.remove_connection(connection_id)?;
438 let mut connections = pool.user_connection_ids(user_id);
439 connections.next().is_none()
440 };
441
442 self.leave_room_for_connection(connection_id, user_id)
443 .await
444 .trace_err();
445 if decline_calls {
446 if let Some(room) = self
447 .app_state
448 .db
449 .decline_call(None, user_id)
450 .await
451 .trace_err()
452 {
453 self.room_updated(&room);
454 }
455 }
456
457 self.update_user_contacts(user_id).await?;
458
459 Ok(())
460 }
461
462 pub async fn invite_code_redeemed(
463 self: &Arc<Self>,
464 inviter_id: UserId,
465 invitee_id: UserId,
466 ) -> Result<()> {
467 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
468 if let Some(code) = &user.invite_code {
469 let pool = self.connection_pool().await;
470 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
471 for connection_id in pool.user_connection_ids(inviter_id) {
472 self.peer.send(
473 connection_id,
474 proto::UpdateContacts {
475 contacts: vec![invitee_contact.clone()],
476 ..Default::default()
477 },
478 )?;
479 self.peer.send(
480 connection_id,
481 proto::UpdateInviteInfo {
482 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
483 count: user.invite_count as u32,
484 },
485 )?;
486 }
487 }
488 }
489 Ok(())
490 }
491
492 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
493 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
494 if let Some(invite_code) = &user.invite_code {
495 let pool = self.connection_pool().await;
496 for connection_id in pool.user_connection_ids(user_id) {
497 self.peer.send(
498 connection_id,
499 proto::UpdateInviteInfo {
500 url: format!(
501 "{}{}",
502 self.app_state.config.invite_link_prefix, invite_code
503 ),
504 count: user.invite_count as u32,
505 },
506 )?;
507 }
508 }
509 }
510 Ok(())
511 }
512
513 async fn ping(
514 self: Arc<Server>,
515 _: proto::Ping,
516 response: Response<proto::Ping>,
517 _session: Session,
518 ) -> Result<()> {
519 response.send(proto::Ack {})?;
520 Ok(())
521 }
522
523 async fn create_room(
524 self: Arc<Server>,
525 _request: proto::CreateRoom,
526 response: Response<proto::CreateRoom>,
527 session: Session,
528 ) -> Result<()> {
529 let room = self
530 .app_state
531 .db
532 .create_room(session.user_id, session.connection_id)
533 .await?;
534
535 let live_kit_connection_info =
536 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
537 if let Some(_) = live_kit
538 .create_room(room.live_kit_room.clone())
539 .await
540 .trace_err()
541 {
542 if let Some(token) = live_kit
543 .room_token(&room.live_kit_room, &session.connection_id.to_string())
544 .trace_err()
545 {
546 Some(proto::LiveKitConnectionInfo {
547 server_url: live_kit.url().into(),
548 token,
549 })
550 } else {
551 None
552 }
553 } else {
554 None
555 }
556 } else {
557 None
558 };
559
560 response.send(proto::CreateRoomResponse {
561 room: Some(room),
562 live_kit_connection_info,
563 })?;
564 self.update_user_contacts(session.user_id).await?;
565 Ok(())
566 }
567
568 async fn join_room(
569 self: Arc<Server>,
570 request: proto::JoinRoom,
571 response: Response<proto::JoinRoom>,
572 session: Session,
573 ) -> Result<()> {
574 let room = self
575 .app_state
576 .db
577 .join_room(
578 RoomId::from_proto(request.id),
579 session.user_id,
580 session.connection_id,
581 )
582 .await?;
583 for connection_id in self
584 .connection_pool()
585 .await
586 .user_connection_ids(session.user_id)
587 {
588 self.peer
589 .send(connection_id, proto::CallCanceled {})
590 .trace_err();
591 }
592
593 let live_kit_connection_info =
594 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
595 if let Some(token) = live_kit
596 .room_token(&room.live_kit_room, &session.connection_id.to_string())
597 .trace_err()
598 {
599 Some(proto::LiveKitConnectionInfo {
600 server_url: live_kit.url().into(),
601 token,
602 })
603 } else {
604 None
605 }
606 } else {
607 None
608 };
609
610 self.room_updated(&room);
611 response.send(proto::JoinRoomResponse {
612 room: Some(room),
613 live_kit_connection_info,
614 })?;
615
616 self.update_user_contacts(session.user_id).await?;
617 Ok(())
618 }
619
620 async fn leave_room(
621 self: Arc<Server>,
622 _message: proto::LeaveRoom,
623 session: Session,
624 ) -> Result<()> {
625 self.leave_room_for_connection(session.connection_id, session.user_id)
626 .await
627 }
628
629 async fn leave_room_for_connection(
630 self: &Arc<Server>,
631 leaving_connection_id: ConnectionId,
632 leaving_user_id: UserId,
633 ) -> Result<()> {
634 let mut contacts_to_update = HashSet::default();
635
636 let Some(left_room) = self.app_state.db.leave_room(leaving_connection_id).await? else {
637 return Err(anyhow!("no room to leave"))?;
638 };
639 contacts_to_update.insert(leaving_user_id);
640
641 for project in left_room.left_projects.into_values() {
642 for connection_id in project.connection_ids {
643 if project.host_user_id == leaving_user_id {
644 self.peer
645 .send(
646 connection_id,
647 proto::UnshareProject {
648 project_id: project.id.to_proto(),
649 },
650 )
651 .trace_err();
652 } else {
653 self.peer
654 .send(
655 connection_id,
656 proto::RemoveProjectCollaborator {
657 project_id: project.id.to_proto(),
658 peer_id: leaving_connection_id.0,
659 },
660 )
661 .trace_err();
662 }
663 }
664
665 self.peer
666 .send(
667 leaving_connection_id,
668 proto::UnshareProject {
669 project_id: project.id.to_proto(),
670 },
671 )
672 .trace_err();
673 }
674
675 self.room_updated(&left_room.room);
676 {
677 let pool = self.connection_pool().await;
678 for canceled_user_id in left_room.canceled_calls_to_user_ids {
679 for connection_id in pool.user_connection_ids(canceled_user_id) {
680 self.peer
681 .send(connection_id, proto::CallCanceled {})
682 .trace_err();
683 }
684 contacts_to_update.insert(canceled_user_id);
685 }
686 }
687
688 for contact_user_id in contacts_to_update {
689 self.update_user_contacts(contact_user_id).await?;
690 }
691
692 if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
693 live_kit
694 .remove_participant(
695 left_room.room.live_kit_room.clone(),
696 leaving_connection_id.to_string(),
697 )
698 .await
699 .trace_err();
700
701 if left_room.room.participants.is_empty() {
702 live_kit
703 .delete_room(left_room.room.live_kit_room)
704 .await
705 .trace_err();
706 }
707 }
708
709 Ok(())
710 }
711
712 async fn call(
713 self: Arc<Server>,
714 request: proto::Call,
715 response: Response<proto::Call>,
716 session: Session,
717 ) -> Result<()> {
718 let room_id = RoomId::from_proto(request.room_id);
719 let calling_user_id = session.user_id;
720 let calling_connection_id = session.connection_id;
721 let called_user_id = UserId::from_proto(request.called_user_id);
722 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
723 if !self
724 .app_state
725 .db
726 .has_contact(calling_user_id, called_user_id)
727 .await?
728 {
729 return Err(anyhow!("cannot call a user who isn't a contact"))?;
730 }
731
732 let (room, incoming_call) = self
733 .app_state
734 .db
735 .call(
736 room_id,
737 calling_user_id,
738 calling_connection_id,
739 called_user_id,
740 initial_project_id,
741 )
742 .await?;
743 self.room_updated(&room);
744 self.update_user_contacts(called_user_id).await?;
745
746 let mut calls = self
747 .connection_pool()
748 .await
749 .user_connection_ids(called_user_id)
750 .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
751 .collect::<FuturesUnordered<_>>();
752
753 while let Some(call_response) = calls.next().await {
754 match call_response.as_ref() {
755 Ok(_) => {
756 response.send(proto::Ack {})?;
757 return Ok(());
758 }
759 Err(_) => {
760 call_response.trace_err();
761 }
762 }
763 }
764
765 let room = self
766 .app_state
767 .db
768 .call_failed(room_id, called_user_id)
769 .await?;
770 self.room_updated(&room);
771 self.update_user_contacts(called_user_id).await?;
772
773 Err(anyhow!("failed to ring user"))?
774 }
775
776 async fn cancel_call(
777 self: Arc<Server>,
778 request: proto::CancelCall,
779 response: Response<proto::CancelCall>,
780 session: Session,
781 ) -> Result<()> {
782 let called_user_id = UserId::from_proto(request.called_user_id);
783 let room_id = RoomId::from_proto(request.room_id);
784 let room = self
785 .app_state
786 .db
787 .cancel_call(Some(room_id), session.connection_id, called_user_id)
788 .await?;
789 for connection_id in self
790 .connection_pool()
791 .await
792 .user_connection_ids(called_user_id)
793 {
794 self.peer
795 .send(connection_id, proto::CallCanceled {})
796 .trace_err();
797 }
798 self.room_updated(&room);
799 response.send(proto::Ack {})?;
800
801 self.update_user_contacts(called_user_id).await?;
802 Ok(())
803 }
804
805 async fn decline_call(
806 self: Arc<Server>,
807 message: proto::DeclineCall,
808 session: Session,
809 ) -> Result<()> {
810 let room_id = RoomId::from_proto(message.room_id);
811 let room = self
812 .app_state
813 .db
814 .decline_call(Some(room_id), session.user_id)
815 .await?;
816 for connection_id in self
817 .connection_pool()
818 .await
819 .user_connection_ids(session.user_id)
820 {
821 self.peer
822 .send(connection_id, proto::CallCanceled {})
823 .trace_err();
824 }
825 self.room_updated(&room);
826 self.update_user_contacts(session.user_id).await?;
827 Ok(())
828 }
829
830 async fn update_participant_location(
831 self: Arc<Server>,
832 request: proto::UpdateParticipantLocation,
833 response: Response<proto::UpdateParticipantLocation>,
834 session: Session,
835 ) -> Result<()> {
836 let room_id = RoomId::from_proto(request.room_id);
837 let location = request
838 .location
839 .ok_or_else(|| anyhow!("invalid location"))?;
840 let room = self
841 .app_state
842 .db
843 .update_room_participant_location(room_id, session.connection_id, location)
844 .await?;
845 self.room_updated(&room);
846 response.send(proto::Ack {})?;
847 Ok(())
848 }
849
850 fn room_updated(&self, room: &proto::Room) {
851 for participant in &room.participants {
852 self.peer
853 .send(
854 ConnectionId(participant.peer_id),
855 proto::RoomUpdated {
856 room: Some(room.clone()),
857 },
858 )
859 .trace_err();
860 }
861 }
862
863 async fn share_project(
864 self: Arc<Server>,
865 request: proto::ShareProject,
866 response: Response<proto::ShareProject>,
867 session: Session,
868 ) -> Result<()> {
869 let (project_id, room) = self
870 .app_state
871 .db
872 .share_project(
873 RoomId::from_proto(request.room_id),
874 session.connection_id,
875 &request.worktrees,
876 )
877 .await?;
878 response.send(proto::ShareProjectResponse {
879 project_id: project_id.to_proto(),
880 })?;
881 self.room_updated(&room);
882
883 Ok(())
884 }
885
886 async fn unshare_project(
887 self: Arc<Server>,
888 message: proto::UnshareProject,
889 session: Session,
890 ) -> Result<()> {
891 let project_id = ProjectId::from_proto(message.project_id);
892
893 let (room, guest_connection_ids) = self
894 .app_state
895 .db
896 .unshare_project(project_id, session.connection_id)
897 .await?;
898
899 broadcast(session.connection_id, guest_connection_ids, |conn_id| {
900 self.peer.send(conn_id, message.clone())
901 });
902 self.room_updated(&room);
903
904 Ok(())
905 }
906
907 async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
908 let contacts = self.app_state.db.get_contacts(user_id).await?;
909 let busy = self.app_state.db.is_user_busy(user_id).await?;
910 let pool = self.connection_pool().await;
911 let updated_contact = contact_for_user(user_id, false, busy, &pool);
912 for contact in contacts {
913 if let db::Contact::Accepted {
914 user_id: contact_user_id,
915 ..
916 } = contact
917 {
918 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
919 self.peer
920 .send(
921 contact_conn_id,
922 proto::UpdateContacts {
923 contacts: vec![updated_contact.clone()],
924 remove_contacts: Default::default(),
925 incoming_requests: Default::default(),
926 remove_incoming_requests: Default::default(),
927 outgoing_requests: Default::default(),
928 remove_outgoing_requests: Default::default(),
929 },
930 )
931 .trace_err();
932 }
933 }
934 }
935 Ok(())
936 }
937
938 async fn join_project(
939 self: Arc<Server>,
940 request: proto::JoinProject,
941 response: Response<proto::JoinProject>,
942 session: Session,
943 ) -> Result<()> {
944 let project_id = ProjectId::from_proto(request.project_id);
945 let guest_user_id = session.user_id;
946
947 tracing::info!(%project_id, "join project");
948
949 let (project, replica_id) = self
950 .app_state
951 .db
952 .join_project(project_id, session.connection_id)
953 .await?;
954
955 let collaborators = project
956 .collaborators
957 .iter()
958 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
959 .map(|collaborator| proto::Collaborator {
960 peer_id: collaborator.connection_id as u32,
961 replica_id: collaborator.replica_id.0 as u32,
962 user_id: collaborator.user_id.to_proto(),
963 })
964 .collect::<Vec<_>>();
965 let worktrees = project
966 .worktrees
967 .iter()
968 .map(|(id, worktree)| proto::WorktreeMetadata {
969 id: id.to_proto(),
970 root_name: worktree.root_name.clone(),
971 visible: worktree.visible,
972 abs_path: worktree.abs_path.clone(),
973 })
974 .collect::<Vec<_>>();
975
976 for collaborator in &collaborators {
977 self.peer
978 .send(
979 ConnectionId(collaborator.peer_id),
980 proto::AddProjectCollaborator {
981 project_id: project_id.to_proto(),
982 collaborator: Some(proto::Collaborator {
983 peer_id: session.connection_id.0,
984 replica_id: replica_id.0 as u32,
985 user_id: guest_user_id.to_proto(),
986 }),
987 },
988 )
989 .trace_err();
990 }
991
992 // First, we send the metadata associated with each worktree.
993 response.send(proto::JoinProjectResponse {
994 worktrees: worktrees.clone(),
995 replica_id: replica_id.0 as u32,
996 collaborators: collaborators.clone(),
997 language_servers: project.language_servers.clone(),
998 })?;
999
1000 for (worktree_id, worktree) in project.worktrees {
1001 #[cfg(any(test, feature = "test-support"))]
1002 const MAX_CHUNK_SIZE: usize = 2;
1003 #[cfg(not(any(test, feature = "test-support")))]
1004 const MAX_CHUNK_SIZE: usize = 256;
1005
1006 // Stream this worktree's entries.
1007 let message = proto::UpdateWorktree {
1008 project_id: project_id.to_proto(),
1009 worktree_id: worktree_id.to_proto(),
1010 abs_path: worktree.abs_path.clone(),
1011 root_name: worktree.root_name,
1012 updated_entries: worktree.entries,
1013 removed_entries: Default::default(),
1014 scan_id: worktree.scan_id,
1015 is_last_update: worktree.is_complete,
1016 };
1017 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1018 self.peer.send(session.connection_id, update.clone())?;
1019 }
1020
1021 // Stream this worktree's diagnostics.
1022 for summary in worktree.diagnostic_summaries {
1023 self.peer.send(
1024 session.connection_id,
1025 proto::UpdateDiagnosticSummary {
1026 project_id: project_id.to_proto(),
1027 worktree_id: worktree.id.to_proto(),
1028 summary: Some(summary),
1029 },
1030 )?;
1031 }
1032 }
1033
1034 for language_server in &project.language_servers {
1035 self.peer.send(
1036 session.connection_id,
1037 proto::UpdateLanguageServer {
1038 project_id: project_id.to_proto(),
1039 language_server_id: language_server.id,
1040 variant: Some(
1041 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1042 proto::LspDiskBasedDiagnosticsUpdated {},
1043 ),
1044 ),
1045 },
1046 )?;
1047 }
1048
1049 Ok(())
1050 }
1051
1052 async fn leave_project(
1053 self: Arc<Server>,
1054 request: proto::LeaveProject,
1055 session: Session,
1056 ) -> Result<()> {
1057 let sender_id = session.connection_id;
1058 let project_id = ProjectId::from_proto(request.project_id);
1059 let project;
1060 {
1061 project = self
1062 .app_state
1063 .db
1064 .leave_project(project_id, sender_id)
1065 .await?;
1066 tracing::info!(
1067 %project_id,
1068 host_user_id = %project.host_user_id,
1069 host_connection_id = %project.host_connection_id,
1070 "leave project"
1071 );
1072
1073 broadcast(sender_id, project.connection_ids, |conn_id| {
1074 self.peer.send(
1075 conn_id,
1076 proto::RemoveProjectCollaborator {
1077 project_id: project_id.to_proto(),
1078 peer_id: sender_id.0,
1079 },
1080 )
1081 });
1082 }
1083
1084 Ok(())
1085 }
1086
1087 async fn update_project(
1088 self: Arc<Server>,
1089 request: proto::UpdateProject,
1090 response: Response<proto::UpdateProject>,
1091 session: Session,
1092 ) -> Result<()> {
1093 let project_id = ProjectId::from_proto(request.project_id);
1094 let (room, guest_connection_ids) = self
1095 .app_state
1096 .db
1097 .update_project(project_id, session.connection_id, &request.worktrees)
1098 .await?;
1099 broadcast(
1100 session.connection_id,
1101 guest_connection_ids,
1102 |connection_id| {
1103 self.peer
1104 .forward_send(session.connection_id, connection_id, request.clone())
1105 },
1106 );
1107 self.room_updated(&room);
1108 response.send(proto::Ack {})?;
1109
1110 Ok(())
1111 }
1112
1113 async fn update_worktree(
1114 self: Arc<Server>,
1115 request: proto::UpdateWorktree,
1116 response: Response<proto::UpdateWorktree>,
1117 session: Session,
1118 ) -> Result<()> {
1119 let guest_connection_ids = self
1120 .app_state
1121 .db
1122 .update_worktree(&request, session.connection_id)
1123 .await?;
1124
1125 broadcast(
1126 session.connection_id,
1127 guest_connection_ids,
1128 |connection_id| {
1129 self.peer
1130 .forward_send(session.connection_id, connection_id, request.clone())
1131 },
1132 );
1133 response.send(proto::Ack {})?;
1134 Ok(())
1135 }
1136
1137 async fn update_diagnostic_summary(
1138 self: Arc<Server>,
1139 request: proto::UpdateDiagnosticSummary,
1140 response: Response<proto::UpdateDiagnosticSummary>,
1141 session: Session,
1142 ) -> Result<()> {
1143 let guest_connection_ids = self
1144 .app_state
1145 .db
1146 .update_diagnostic_summary(&request, session.connection_id)
1147 .await?;
1148
1149 broadcast(
1150 session.connection_id,
1151 guest_connection_ids,
1152 |connection_id| {
1153 self.peer
1154 .forward_send(session.connection_id, connection_id, request.clone())
1155 },
1156 );
1157
1158 response.send(proto::Ack {})?;
1159 Ok(())
1160 }
1161
1162 async fn start_language_server(
1163 self: Arc<Server>,
1164 request: proto::StartLanguageServer,
1165 session: Session,
1166 ) -> Result<()> {
1167 let guest_connection_ids = self
1168 .app_state
1169 .db
1170 .start_language_server(&request, session.connection_id)
1171 .await?;
1172
1173 broadcast(
1174 session.connection_id,
1175 guest_connection_ids,
1176 |connection_id| {
1177 self.peer
1178 .forward_send(session.connection_id, connection_id, request.clone())
1179 },
1180 );
1181 Ok(())
1182 }
1183
1184 async fn update_language_server(
1185 self: Arc<Server>,
1186 request: proto::UpdateLanguageServer,
1187 session: Session,
1188 ) -> Result<()> {
1189 let project_id = ProjectId::from_proto(request.project_id);
1190 let project_connection_ids = self
1191 .app_state
1192 .db
1193 .project_connection_ids(project_id, session.connection_id)
1194 .await?;
1195 broadcast(
1196 session.connection_id,
1197 project_connection_ids,
1198 |connection_id| {
1199 self.peer
1200 .forward_send(session.connection_id, connection_id, request.clone())
1201 },
1202 );
1203 Ok(())
1204 }
1205
1206 async fn forward_project_request<T>(
1207 self: Arc<Server>,
1208 request: T,
1209 response: Response<T>,
1210 session: Session,
1211 ) -> Result<()>
1212 where
1213 T: EntityMessage + RequestMessage,
1214 {
1215 let project_id = ProjectId::from_proto(request.remote_entity_id());
1216 let collaborators = self
1217 .app_state
1218 .db
1219 .project_collaborators(project_id, session.connection_id)
1220 .await?;
1221 let host = collaborators
1222 .iter()
1223 .find(|collaborator| collaborator.is_host)
1224 .ok_or_else(|| anyhow!("host not found"))?;
1225
1226 let payload = self
1227 .peer
1228 .forward_request(
1229 session.connection_id,
1230 ConnectionId(host.connection_id as u32),
1231 request,
1232 )
1233 .await?;
1234
1235 response.send(payload)?;
1236 Ok(())
1237 }
1238
1239 async fn save_buffer(
1240 self: Arc<Server>,
1241 request: proto::SaveBuffer,
1242 response: Response<proto::SaveBuffer>,
1243 session: Session,
1244 ) -> Result<()> {
1245 let project_id = ProjectId::from_proto(request.project_id);
1246 let collaborators = self
1247 .app_state
1248 .db
1249 .project_collaborators(project_id, session.connection_id)
1250 .await?;
1251 let host = collaborators
1252 .into_iter()
1253 .find(|collaborator| collaborator.is_host)
1254 .ok_or_else(|| anyhow!("host not found"))?;
1255 let host_connection_id = ConnectionId(host.connection_id as u32);
1256 let response_payload = self
1257 .peer
1258 .forward_request(session.connection_id, host_connection_id, request.clone())
1259 .await?;
1260
1261 let mut collaborators = self
1262 .app_state
1263 .db
1264 .project_collaborators(project_id, session.connection_id)
1265 .await?;
1266 collaborators
1267 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1268 let project_connection_ids = collaborators
1269 .into_iter()
1270 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1271 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1272 self.peer
1273 .forward_send(host_connection_id, conn_id, response_payload.clone())
1274 });
1275 response.send(response_payload)?;
1276 Ok(())
1277 }
1278
1279 async fn create_buffer_for_peer(
1280 self: Arc<Server>,
1281 request: proto::CreateBufferForPeer,
1282 session: Session,
1283 ) -> Result<()> {
1284 self.peer.forward_send(
1285 session.connection_id,
1286 ConnectionId(request.peer_id),
1287 request,
1288 )?;
1289 Ok(())
1290 }
1291
1292 async fn update_buffer(
1293 self: Arc<Server>,
1294 request: proto::UpdateBuffer,
1295 response: Response<proto::UpdateBuffer>,
1296 session: Session,
1297 ) -> Result<()> {
1298 let project_id = ProjectId::from_proto(request.project_id);
1299 let project_connection_ids = self
1300 .app_state
1301 .db
1302 .project_connection_ids(project_id, session.connection_id)
1303 .await?;
1304
1305 broadcast(
1306 session.connection_id,
1307 project_connection_ids,
1308 |connection_id| {
1309 self.peer
1310 .forward_send(session.connection_id, connection_id, request.clone())
1311 },
1312 );
1313 response.send(proto::Ack {})?;
1314 Ok(())
1315 }
1316
1317 async fn update_buffer_file(
1318 self: Arc<Server>,
1319 request: proto::UpdateBufferFile,
1320 session: Session,
1321 ) -> Result<()> {
1322 let project_id = ProjectId::from_proto(request.project_id);
1323 let project_connection_ids = self
1324 .app_state
1325 .db
1326 .project_connection_ids(project_id, session.connection_id)
1327 .await?;
1328
1329 broadcast(
1330 session.connection_id,
1331 project_connection_ids,
1332 |connection_id| {
1333 self.peer
1334 .forward_send(session.connection_id, connection_id, request.clone())
1335 },
1336 );
1337 Ok(())
1338 }
1339
1340 async fn buffer_reloaded(
1341 self: Arc<Server>,
1342 request: proto::BufferReloaded,
1343 session: Session,
1344 ) -> Result<()> {
1345 let project_id = ProjectId::from_proto(request.project_id);
1346 let project_connection_ids = self
1347 .app_state
1348 .db
1349 .project_connection_ids(project_id, session.connection_id)
1350 .await?;
1351 broadcast(
1352 session.connection_id,
1353 project_connection_ids,
1354 |connection_id| {
1355 self.peer
1356 .forward_send(session.connection_id, connection_id, request.clone())
1357 },
1358 );
1359 Ok(())
1360 }
1361
1362 async fn buffer_saved(
1363 self: Arc<Server>,
1364 request: proto::BufferSaved,
1365 session: Session,
1366 ) -> Result<()> {
1367 let project_id = ProjectId::from_proto(request.project_id);
1368 let project_connection_ids = self
1369 .app_state
1370 .db
1371 .project_connection_ids(project_id, session.connection_id)
1372 .await?;
1373 broadcast(
1374 session.connection_id,
1375 project_connection_ids,
1376 |connection_id| {
1377 self.peer
1378 .forward_send(session.connection_id, connection_id, request.clone())
1379 },
1380 );
1381 Ok(())
1382 }
1383
1384 async fn follow(
1385 self: Arc<Self>,
1386 request: proto::Follow,
1387 response: Response<proto::Follow>,
1388 session: Session,
1389 ) -> Result<()> {
1390 let project_id = ProjectId::from_proto(request.project_id);
1391 let leader_id = ConnectionId(request.leader_id);
1392 let follower_id = session.connection_id;
1393 let project_connection_ids = self
1394 .app_state
1395 .db
1396 .project_connection_ids(project_id, session.connection_id)
1397 .await?;
1398
1399 if !project_connection_ids.contains(&leader_id) {
1400 Err(anyhow!("no such peer"))?;
1401 }
1402
1403 let mut response_payload = self
1404 .peer
1405 .forward_request(session.connection_id, leader_id, request)
1406 .await?;
1407 response_payload
1408 .views
1409 .retain(|view| view.leader_id != Some(follower_id.0));
1410 response.send(response_payload)?;
1411 Ok(())
1412 }
1413
1414 async fn unfollow(self: Arc<Self>, request: proto::Unfollow, session: Session) -> Result<()> {
1415 let project_id = ProjectId::from_proto(request.project_id);
1416 let leader_id = ConnectionId(request.leader_id);
1417 let project_connection_ids = self
1418 .app_state
1419 .db
1420 .project_connection_ids(project_id, session.connection_id)
1421 .await?;
1422 if !project_connection_ids.contains(&leader_id) {
1423 Err(anyhow!("no such peer"))?;
1424 }
1425 self.peer
1426 .forward_send(session.connection_id, leader_id, request)?;
1427 Ok(())
1428 }
1429
1430 async fn update_followers(
1431 self: Arc<Self>,
1432 request: proto::UpdateFollowers,
1433 session: Session,
1434 ) -> Result<()> {
1435 let project_id = ProjectId::from_proto(request.project_id);
1436 let project_connection_ids = session
1437 .db
1438 .lock()
1439 .await
1440 .project_connection_ids(project_id, session.connection_id)
1441 .await?;
1442
1443 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1444 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1445 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1446 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1447 });
1448 for follower_id in &request.follower_ids {
1449 let follower_id = ConnectionId(*follower_id);
1450 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1451 self.peer
1452 .forward_send(session.connection_id, follower_id, request.clone())?;
1453 }
1454 }
1455 Ok(())
1456 }
1457
1458 async fn get_users(
1459 self: Arc<Server>,
1460 request: proto::GetUsers,
1461 response: Response<proto::GetUsers>,
1462 _session: Session,
1463 ) -> Result<()> {
1464 let user_ids = request
1465 .user_ids
1466 .into_iter()
1467 .map(UserId::from_proto)
1468 .collect();
1469 let users = self
1470 .app_state
1471 .db
1472 .get_users_by_ids(user_ids)
1473 .await?
1474 .into_iter()
1475 .map(|user| proto::User {
1476 id: user.id.to_proto(),
1477 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1478 github_login: user.github_login,
1479 })
1480 .collect();
1481 response.send(proto::UsersResponse { users })?;
1482 Ok(())
1483 }
1484
1485 async fn fuzzy_search_users(
1486 self: Arc<Server>,
1487 request: proto::FuzzySearchUsers,
1488 response: Response<proto::FuzzySearchUsers>,
1489 session: Session,
1490 ) -> Result<()> {
1491 let query = request.query;
1492 let db = &self.app_state.db;
1493 let users = match query.len() {
1494 0 => vec![],
1495 1 | 2 => db
1496 .get_user_by_github_account(&query, None)
1497 .await?
1498 .into_iter()
1499 .collect(),
1500 _ => db.fuzzy_search_users(&query, 10).await?,
1501 };
1502 let users = users
1503 .into_iter()
1504 .filter(|user| user.id != session.user_id)
1505 .map(|user| proto::User {
1506 id: user.id.to_proto(),
1507 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1508 github_login: user.github_login,
1509 })
1510 .collect();
1511 response.send(proto::UsersResponse { users })?;
1512 Ok(())
1513 }
1514
1515 async fn request_contact(
1516 self: Arc<Server>,
1517 request: proto::RequestContact,
1518 response: Response<proto::RequestContact>,
1519 session: Session,
1520 ) -> Result<()> {
1521 let requester_id = session.user_id;
1522 let responder_id = UserId::from_proto(request.responder_id);
1523 if requester_id == responder_id {
1524 return Err(anyhow!("cannot add yourself as a contact"))?;
1525 }
1526
1527 self.app_state
1528 .db
1529 .send_contact_request(requester_id, responder_id)
1530 .await?;
1531
1532 // Update outgoing contact requests of requester
1533 let mut update = proto::UpdateContacts::default();
1534 update.outgoing_requests.push(responder_id.to_proto());
1535 for connection_id in self
1536 .connection_pool()
1537 .await
1538 .user_connection_ids(requester_id)
1539 {
1540 self.peer.send(connection_id, update.clone())?;
1541 }
1542
1543 // Update incoming contact requests of responder
1544 let mut update = proto::UpdateContacts::default();
1545 update
1546 .incoming_requests
1547 .push(proto::IncomingContactRequest {
1548 requester_id: requester_id.to_proto(),
1549 should_notify: true,
1550 });
1551 for connection_id in self
1552 .connection_pool()
1553 .await
1554 .user_connection_ids(responder_id)
1555 {
1556 self.peer.send(connection_id, update.clone())?;
1557 }
1558
1559 response.send(proto::Ack {})?;
1560 Ok(())
1561 }
1562
1563 async fn respond_to_contact_request(
1564 self: Arc<Server>,
1565 request: proto::RespondToContactRequest,
1566 response: Response<proto::RespondToContactRequest>,
1567 session: Session,
1568 ) -> Result<()> {
1569 let responder_id = session.user_id;
1570 let requester_id = UserId::from_proto(request.requester_id);
1571 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1572 self.app_state
1573 .db
1574 .dismiss_contact_notification(responder_id, requester_id)
1575 .await?;
1576 } else {
1577 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1578 self.app_state
1579 .db
1580 .respond_to_contact_request(responder_id, requester_id, accept)
1581 .await?;
1582 let busy = self.app_state.db.is_user_busy(requester_id).await?;
1583
1584 let pool = self.connection_pool().await;
1585 // Update responder with new contact
1586 let mut update = proto::UpdateContacts::default();
1587 if accept {
1588 update
1589 .contacts
1590 .push(contact_for_user(requester_id, false, busy, &pool));
1591 }
1592 update
1593 .remove_incoming_requests
1594 .push(requester_id.to_proto());
1595 for connection_id in pool.user_connection_ids(responder_id) {
1596 self.peer.send(connection_id, update.clone())?;
1597 }
1598
1599 // Update requester with new contact
1600 let mut update = proto::UpdateContacts::default();
1601 if accept {
1602 update
1603 .contacts
1604 .push(contact_for_user(responder_id, true, busy, &pool));
1605 }
1606 update
1607 .remove_outgoing_requests
1608 .push(responder_id.to_proto());
1609 for connection_id in pool.user_connection_ids(requester_id) {
1610 self.peer.send(connection_id, update.clone())?;
1611 }
1612 }
1613
1614 response.send(proto::Ack {})?;
1615 Ok(())
1616 }
1617
1618 async fn remove_contact(
1619 self: Arc<Server>,
1620 request: proto::RemoveContact,
1621 response: Response<proto::RemoveContact>,
1622 session: Session,
1623 ) -> Result<()> {
1624 let requester_id = session.user_id;
1625 let responder_id = UserId::from_proto(request.user_id);
1626 self.app_state
1627 .db
1628 .remove_contact(requester_id, responder_id)
1629 .await?;
1630
1631 // Update outgoing contact requests of requester
1632 let mut update = proto::UpdateContacts::default();
1633 update
1634 .remove_outgoing_requests
1635 .push(responder_id.to_proto());
1636 for connection_id in self
1637 .connection_pool()
1638 .await
1639 .user_connection_ids(requester_id)
1640 {
1641 self.peer.send(connection_id, update.clone())?;
1642 }
1643
1644 // Update incoming contact requests of responder
1645 let mut update = proto::UpdateContacts::default();
1646 update
1647 .remove_incoming_requests
1648 .push(requester_id.to_proto());
1649 for connection_id in self
1650 .connection_pool()
1651 .await
1652 .user_connection_ids(responder_id)
1653 {
1654 self.peer.send(connection_id, update.clone())?;
1655 }
1656
1657 response.send(proto::Ack {})?;
1658 Ok(())
1659 }
1660
1661 async fn update_diff_base(
1662 self: Arc<Server>,
1663 request: proto::UpdateDiffBase,
1664 session: Session,
1665 ) -> Result<()> {
1666 let project_id = ProjectId::from_proto(request.project_id);
1667 let project_connection_ids = self
1668 .app_state
1669 .db
1670 .project_connection_ids(project_id, session.connection_id)
1671 .await?;
1672 broadcast(
1673 session.connection_id,
1674 project_connection_ids,
1675 |connection_id| {
1676 self.peer
1677 .forward_send(session.connection_id, connection_id, request.clone())
1678 },
1679 );
1680 Ok(())
1681 }
1682
1683 async fn get_private_user_info(
1684 self: Arc<Self>,
1685 _request: proto::GetPrivateUserInfo,
1686 response: Response<proto::GetPrivateUserInfo>,
1687 session: Session,
1688 ) -> Result<()> {
1689 let metrics_id = self
1690 .app_state
1691 .db
1692 .get_user_metrics_id(session.user_id)
1693 .await?;
1694 let user = self
1695 .app_state
1696 .db
1697 .get_user_by_id(session.user_id)
1698 .await?
1699 .ok_or_else(|| anyhow!("user not found"))?;
1700 response.send(proto::GetPrivateUserInfoResponse {
1701 metrics_id,
1702 staff: user.admin,
1703 })?;
1704 Ok(())
1705 }
1706
    /// Locks the server's connection pool and wraps the lock in a `ConnectionPoolGuard`.
    ///
    /// In test builds the task yields both before and after acquiring the lock, presumably to
    /// surface task interleavings in tests (confirm intent before changing).
    pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
1718
1719 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1720 ServerSnapshot {
1721 connection_pool: self.connection_pool().await,
1722 peer: &self.peer,
1723 }
1724 }
1725}
1726
1727impl<'a> Deref for ConnectionPoolGuard<'a> {
1728 type Target = ConnectionPool;
1729
1730 fn deref(&self) -> &Self::Target {
1731 &*self.guard
1732 }
1733}
1734
1735impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1736 fn deref_mut(&mut self) -> &mut Self::Target {
1737 &mut *self.guard
1738 }
1739}
1740
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants every time a guard is released. In other
    /// builds this body is empty, and dropping the guard only drops the inner lock guard.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1747
1748impl Executor for RealExecutor {
1749 type Sleep = Sleep;
1750
1751 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
1752 tokio::task::spawn(future);
1753 }
1754
1755 fn sleep(&self, duration: Duration) -> Self::Sleep {
1756 tokio::time::sleep(duration)
1757 }
1758}
1759
1760fn broadcast<F>(
1761 sender_id: ConnectionId,
1762 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1763 mut f: F,
1764) where
1765 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1766{
1767 for receiver_id in receiver_ids {
1768 if receiver_id != sender_id {
1769 f(receiver_id).trace_err();
1770 }
1771 }
1772}
1773
lazy_static! {
    // Name of the request header carrying the client's RPC protocol version; it is the
    // header name reported by `ProtocolVersion`'s `Header` impl.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
1777
1778pub struct ProtocolVersion(u32);
1779
1780impl Header for ProtocolVersion {
1781 fn name() -> &'static HeaderName {
1782 &ZED_PROTOCOL_VERSION
1783 }
1784
1785 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1786 where
1787 Self: Sized,
1788 I: Iterator<Item = &'i axum::http::HeaderValue>,
1789 {
1790 let version = values
1791 .next()
1792 .ok_or_else(axum::headers::Error::invalid)?
1793 .to_str()
1794 .map_err(|_| axum::headers::Error::invalid())?
1795 .parse()
1796 .map_err(|_| axum::headers::Error::invalid())?;
1797 Ok(Self(version))
1798 }
1799
1800 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1801 values.extend([self.0.to_string().parse().unwrap()]);
1802 }
1803}
1804
/// Builds the HTTP router for the collab server.
///
/// `/rpc` is wrapped by a layer that provides the app state extension and runs
/// `auth::validate_header`. `/metrics` is registered after that `.layer` call, and axum's
/// `Router::layer` only wraps routes added before it, so `/metrics` is not behind that auth
/// middleware. The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer above, so it is not wrapped by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1816
/// Axum handler for `/rpc`: upgrades the request to a WebSocket and hands the resulting
/// connection to `Server::handle_connection`.
///
/// Responds with `426 Upgrade Required` when the client's `x-zed-protocol-version` header
/// differs from `rpc::PROTOCOL_VERSION`. The `User` comes from a request extension,
/// presumably inserted by the `auth::validate_header` middleware on this route; confirm this.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Adapt the axum socket to tungstenite message types so it can be wrapped in an rpc
        // `Connection`: incoming items go through `to_tungstenite_message`, and outgoing items
        // go through `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // A connection-level error is logged (util::ResultExt::log_err), not propagated.
            server
                .handle_connection(connection, socket_address, user, None, RealExecutor)
                .await
                .log_err();
        }
    })
}
1847
1848pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1849 let connections = server
1850 .connection_pool()
1851 .await
1852 .connections()
1853 .filter(|connection| !connection.admin)
1854 .count();
1855
1856 METRIC_CONNECTIONS.set(connections as _);
1857
1858 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1859 METRIC_SHARED_PROJECTS.set(shared_projects as _);
1860
1861 let encoder = prometheus::TextEncoder::new();
1862 let metric_families = prometheus::gather();
1863 let encoded_metrics = encoder
1864 .encode_to_string(&metric_families)
1865 .map_err(|err| anyhow!("{}", err))?;
1866 Ok(encoded_metrics)
1867}
1868
1869fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1870 match message {
1871 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1872 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1873 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1874 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1875 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1876 code: frame.code.into(),
1877 reason: frame.reason,
1878 })),
1879 }
1880}
1881
1882fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1883 match message {
1884 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1885 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1886 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1887 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1888 AxumMessage::Close(frame) => {
1889 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1890 code: frame.code.into(),
1891 reason: frame.reason,
1892 }))
1893 }
1894 }
1895}
1896
1897fn build_initial_contacts_update(
1898 contacts: Vec<db::Contact>,
1899 pool: &ConnectionPool,
1900) -> proto::UpdateContacts {
1901 let mut update = proto::UpdateContacts::default();
1902
1903 for contact in contacts {
1904 match contact {
1905 db::Contact::Accepted {
1906 user_id,
1907 should_notify,
1908 busy,
1909 } => {
1910 update
1911 .contacts
1912 .push(contact_for_user(user_id, should_notify, busy, &pool));
1913 }
1914 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1915 db::Contact::Incoming {
1916 user_id,
1917 should_notify,
1918 } => update
1919 .incoming_requests
1920 .push(proto::IncomingContactRequest {
1921 requester_id: user_id.to_proto(),
1922 should_notify,
1923 }),
1924 }
1925 }
1926
1927 update
1928}
1929
1930fn contact_for_user(
1931 user_id: UserId,
1932 should_notify: bool,
1933 busy: bool,
1934 pool: &ConnectionPool,
1935) -> proto::Contact {
1936 proto::Contact {
1937 user_id: user_id.to_proto(),
1938 online: pool.is_user_online(user_id),
1939 busy,
1940 should_notify,
1941 }
1942}
1943
1944pub trait ResultExt {
1945 type Ok;
1946
1947 fn trace_err(self) -> Option<Self::Ok>;
1948}
1949
1950impl<T, E> ResultExt for Result<T, E>
1951where
1952 E: std::fmt::Debug,
1953{
1954 type Ok = T;
1955
1956 fn trace_err(self) -> Option<T> {
1957 match self {
1958 Ok(value) => Some(value),
1959 Err(error) => {
1960 tracing::error!("{:?}", error);
1961 None
1962 }
1963 }
1964 }
1965}