1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection drops before leaving the user's room,
/// declining their pending call (if no other connection is online) and updating contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Delay after `Server::start` before stale rooms (as reported by `stale_room_ids`) are
/// refreshed and cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers, and drives each client connection.
// Field order is kept as-is because it also determines drop order.
pub struct Server {
    /// The server's epoch; used for the peer id and stale-data cleanup in `start`, and
    /// replaced by `reset` in tests.
    epoch: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connected clients, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; open connections and pending sign-outs subscribe to it.
    teardown: watch::Sender<()>,
}
149
/// Lock guard for the shared connection pool.
///
/// The `Rc` marker makes the guard `!Send`, so it can't be held across an `.await` in a
/// `Send` future. In test builds its `Drop` impl checks the pool's invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server state, produced by `Server::snapshot`.
/// The connection pool stays locked for as long as the snapshot is alive.
// Field order is kept because it determines the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref`, so the guard serializes as the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(epoch: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 epoch: parking_lot::Mutex::new(epoch),
175 peer: Peer::new(epoch.0 as u32),
176 app_state,
177 executor,
178 connection_pool: Default::default(),
179 handlers: Default::default(),
180 teardown: watch::channel(()).0,
181 };
182
183 server
184 .add_request_handler(ping)
185 .add_request_handler(create_room)
186 .add_request_handler(join_room)
187 .add_message_handler(leave_room)
188 .add_request_handler(call)
189 .add_request_handler(cancel_call)
190 .add_message_handler(decline_call)
191 .add_request_handler(update_participant_location)
192 .add_request_handler(share_project)
193 .add_message_handler(unshare_project)
194 .add_request_handler(join_project)
195 .add_message_handler(leave_project)
196 .add_request_handler(update_project)
197 .add_request_handler(update_worktree)
198 .add_message_handler(start_language_server)
199 .add_message_handler(update_language_server)
200 .add_message_handler(update_diagnostic_summary)
201 .add_request_handler(forward_project_request::<proto::GetHover>)
202 .add_request_handler(forward_project_request::<proto::GetDefinition>)
203 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
204 .add_request_handler(forward_project_request::<proto::GetReferences>)
205 .add_request_handler(forward_project_request::<proto::SearchProject>)
206 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
207 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
211 .add_request_handler(forward_project_request::<proto::GetCompletions>)
212 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
213 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
215 .add_request_handler(forward_project_request::<proto::PrepareRename>)
216 .add_request_handler(forward_project_request::<proto::PerformRename>)
217 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
218 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
219 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
222 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
223 .add_message_handler(create_buffer_for_peer)
224 .add_request_handler(update_buffer)
225 .add_message_handler(update_buffer_file)
226 .add_message_handler(buffer_reloaded)
227 .add_message_handler(buffer_saved)
228 .add_request_handler(save_buffer)
229 .add_request_handler(get_users)
230 .add_request_handler(fuzzy_search_users)
231 .add_request_handler(request_contact)
232 .add_request_handler(remove_contact)
233 .add_request_handler(respond_to_contact_request)
234 .add_request_handler(follow)
235 .add_message_handler(unfollow)
236 .add_message_handler(update_followers)
237 .add_message_handler(update_diff_base)
238 .add_request_handler(get_private_user_info);
239
240 Arc::new(server)
241 }
242
243 pub async fn start(&self) -> Result<()> {
244 let epoch = *self.epoch.lock();
245 let app_state = self.app_state.clone();
246 let peer = self.peer.clone();
247 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
248 let pool = self.connection_pool.clone();
249 let live_kit_client = self.app_state.live_kit_client.clone();
250
251 let span = info_span!("start server");
252 let span_enter = span.enter();
253
254 tracing::info!("begin deleting stale projects");
255 app_state
256 .db
257 .delete_stale_projects(epoch, &app_state.config.zed_environment)
258 .await?;
259 tracing::info!("finish deleting stale projects");
260
261 drop(span_enter);
262 self.executor.spawn_detached(
263 async move {
264 tracing::info!("waiting for cleanup timeout");
265 timeout.await;
266 tracing::info!("cleanup timeout expired, retrieving stale rooms");
267 if let Some(room_ids) = app_state
268 .db
269 .stale_room_ids(epoch, &app_state.config.zed_environment)
270 .await
271 .trace_err()
272 {
273 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
274 for room_id in room_ids {
275 let mut contacts_to_update = HashSet::default();
276 let mut canceled_calls_to_user_ids = Vec::new();
277 let mut live_kit_room = String::new();
278 let mut delete_live_kit_room = false;
279
280 if let Ok(mut refreshed_room) =
281 app_state.db.refresh_room(room_id, epoch).await
282 {
283 tracing::info!(
284 room_id = room_id.0,
285 new_participant_count = refreshed_room.room.participants.len(),
286 "refreshed room"
287 );
288 room_updated(&refreshed_room.room, &peer);
289 contacts_to_update
290 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
291 contacts_to_update
292 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
293 canceled_calls_to_user_ids =
294 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
295 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
296 delete_live_kit_room = refreshed_room.room.participants.is_empty();
297 }
298
299 {
300 let pool = pool.lock();
301 for canceled_user_id in canceled_calls_to_user_ids {
302 for connection_id in pool.user_connection_ids(canceled_user_id) {
303 peer.send(
304 connection_id,
305 proto::CallCanceled {
306 room_id: room_id.to_proto(),
307 },
308 )
309 .trace_err();
310 }
311 }
312 }
313
314 for user_id in contacts_to_update {
315 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
316 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
317 if let Some((busy, contacts)) = busy.zip(contacts) {
318 let pool = pool.lock();
319 let updated_contact = contact_for_user(user_id, false, busy, &pool);
320 for contact in contacts {
321 if let db::Contact::Accepted {
322 user_id: contact_user_id,
323 ..
324 } = contact
325 {
326 for contact_conn_id in
327 pool.user_connection_ids(contact_user_id)
328 {
329 peer.send(
330 contact_conn_id,
331 proto::UpdateContacts {
332 contacts: vec![updated_contact.clone()],
333 remove_contacts: Default::default(),
334 incoming_requests: Default::default(),
335 remove_incoming_requests: Default::default(),
336 outgoing_requests: Default::default(),
337 remove_outgoing_requests: Default::default(),
338 },
339 )
340 .trace_err();
341 }
342 }
343 }
344 }
345 }
346
347 if let Some(live_kit) = live_kit_client.as_ref() {
348 if delete_live_kit_room {
349 live_kit.delete_room(live_kit_room).await.trace_err();
350 }
351 }
352 }
353 }
354
355 app_state
356 .db
357 .delete_stale_servers(epoch, &app_state.config.zed_environment)
358 .await
359 .trace_err();
360 }
361 .instrument(span),
362 );
363 Ok(())
364 }
365
366 pub fn teardown(&self) {
367 self.peer.teardown();
368 self.connection_pool.lock().reset();
369 let _ = self.teardown.send(());
370 }
371
372 #[cfg(test)]
373 pub fn reset(&self, epoch: ServerId) {
374 self.teardown();
375 *self.epoch.lock() = epoch;
376 self.peer.reset(epoch.0 as u32);
377 }
378
379 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
380 where
381 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
382 Fut: 'static + Send + Future<Output = Result<()>>,
383 M: EnvelopedMessage,
384 {
385 let prev_handler = self.handlers.insert(
386 TypeId::of::<M>(),
387 Box::new(move |envelope, session| {
388 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
389 let span = info_span!(
390 "handle message",
391 payload_type = envelope.payload_type_name()
392 );
393 span.in_scope(|| {
394 tracing::info!(
395 payload_type = envelope.payload_type_name(),
396 "message received"
397 );
398 });
399 let future = (handler)(*envelope, session);
400 async move {
401 if let Err(error) = future.await {
402 tracing::error!(%error, "error handling message");
403 }
404 }
405 .instrument(span)
406 .boxed()
407 }),
408 );
409 if prev_handler.is_some() {
410 panic!("registered a handler for the same message twice");
411 }
412 self
413 }
414
415 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
416 where
417 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
418 Fut: 'static + Send + Future<Output = Result<()>>,
419 M: EnvelopedMessage,
420 {
421 self.add_handler(move |envelope, session| handler(envelope.payload, session));
422 self
423 }
424
425 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
426 where
427 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
428 Fut: Send + Future<Output = Result<()>>,
429 M: RequestMessage,
430 {
431 let handler = Arc::new(handler);
432 self.add_handler(move |envelope, session| {
433 let receipt = envelope.receipt();
434 let handler = handler.clone();
435 async move {
436 let peer = session.peer.clone();
437 let responded = Arc::new(AtomicBool::default());
438 let response = Response {
439 peer: peer.clone(),
440 responded: responded.clone(),
441 receipt,
442 };
443 match (handler)(envelope.payload, response, session).await {
444 Ok(()) => {
445 if responded.load(std::sync::atomic::Ordering::SeqCst) {
446 Ok(())
447 } else {
448 Err(anyhow!("handler did not send a response"))?
449 }
450 }
451 Err(error) => {
452 peer.respond_with_error(
453 receipt,
454 proto::Error {
455 message: error.to_string(),
456 },
457 )?;
458 Err(error)
459 }
460 }
461 }
462 })
463 }
464
465 pub fn handle_connection(
466 self: &Arc<Self>,
467 connection: Connection,
468 address: String,
469 user: User,
470 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
471 executor: Executor,
472 ) -> impl Future<Output = Result<()>> {
473 let this = self.clone();
474 let user_id = user.id;
475 let login = user.github_login;
476 let span = info_span!("handle connection", %user_id, %login, %address);
477 let mut teardown = self.teardown.subscribe();
478 async move {
479 let (connection_id, handle_io, mut incoming_rx) = this
480 .peer
481 .add_connection(connection, {
482 let executor = executor.clone();
483 move |duration| executor.sleep(duration)
484 });
485
486 tracing::info!(%user_id, %login, ?connection_id, %address, "connection opened");
487 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
488 tracing::info!(%user_id, %login, ?connection_id, %address, "sent hello message");
489
490 if let Some(send_connection_id) = send_connection_id.take() {
491 let _ = send_connection_id.send(connection_id);
492 }
493
494 if !user.connected_once {
495 this.peer.send(connection_id, proto::ShowContacts {})?;
496 this.app_state.db.set_user_connected_once(user_id, true).await?;
497 }
498
499 let (contacts, invite_code) = future::try_join(
500 this.app_state.db.get_contacts(user_id),
501 this.app_state.db.get_invite_code_for_user(user_id)
502 ).await?;
503
504 {
505 let mut pool = this.connection_pool.lock();
506 pool.add_connection(connection_id, user_id, user.admin);
507 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
508
509 if let Some((code, count)) = invite_code {
510 this.peer.send(connection_id, proto::UpdateInviteInfo {
511 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
512 count: count as u32,
513 })?;
514 }
515 }
516
517 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
518 this.peer.send(connection_id, incoming_call)?;
519 }
520
521 let session = Session {
522 user_id,
523 connection_id,
524 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
525 peer: this.peer.clone(),
526 connection_pool: this.connection_pool.clone(),
527 live_kit_client: this.app_state.live_kit_client.clone()
528 };
529 update_user_contacts(user_id, &session).await?;
530
531 let handle_io = handle_io.fuse();
532 futures::pin_mut!(handle_io);
533
534 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
535 // This prevents deadlocks when e.g., client A performs a request to client B and
536 // client B performs a request to client A. If both clients stop processing further
537 // messages until their respective request completes, they won't have a chance to
538 // respond to the other client's request and cause a deadlock.
539 //
540 // This arrangement ensures we will attempt to process earlier messages first, but fall
541 // back to processing messages arrived later in the spirit of making progress.
542 let mut foreground_message_handlers = FuturesUnordered::new();
543 loop {
544 let next_message = incoming_rx.next().fuse();
545 futures::pin_mut!(next_message);
546 futures::select_biased! {
547 _ = teardown.changed().fuse() => return Ok(()),
548 result = handle_io => {
549 if let Err(error) = result {
550 tracing::error!(?error, %user_id, %login, ?connection_id, %address, "error handling I/O");
551 }
552 break;
553 }
554 _ = foreground_message_handlers.next() => {}
555 message = next_message => {
556 if let Some(message) = message {
557 let type_name = message.payload_type_name();
558 let span = tracing::info_span!("receive message", %user_id, %login, ?connection_id, %address, type_name);
559 let span_enter = span.enter();
560 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
561 let is_background = message.is_background();
562 let handle_message = (handler)(message, session.clone());
563 drop(span_enter);
564
565 let handle_message = handle_message.instrument(span);
566 if is_background {
567 executor.spawn_detached(handle_message);
568 } else {
569 foreground_message_handlers.push(handle_message);
570 }
571 } else {
572 tracing::error!(%user_id, %login, ?connection_id, %address, "no message handler");
573 }
574 } else {
575 tracing::info!(%user_id, %login, ?connection_id, %address, "connection closed");
576 break;
577 }
578 }
579 }
580 }
581
582 drop(foreground_message_handlers);
583 tracing::info!(%user_id, %login, ?connection_id, %address, "signing out");
584 if let Err(error) = sign_out(session, teardown, executor).await {
585 tracing::error!(%user_id, %login, ?connection_id, %address, ?error, "error signing out");
586 }
587
588 Ok(())
589 }.instrument(span)
590 }
591
592 pub async fn invite_code_redeemed(
593 self: &Arc<Self>,
594 inviter_id: UserId,
595 invitee_id: UserId,
596 ) -> Result<()> {
597 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
598 if let Some(code) = &user.invite_code {
599 let pool = self.connection_pool.lock();
600 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
601 for connection_id in pool.user_connection_ids(inviter_id) {
602 self.peer.send(
603 connection_id,
604 proto::UpdateContacts {
605 contacts: vec![invitee_contact.clone()],
606 ..Default::default()
607 },
608 )?;
609 self.peer.send(
610 connection_id,
611 proto::UpdateInviteInfo {
612 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
613 count: user.invite_count as u32,
614 },
615 )?;
616 }
617 }
618 }
619 Ok(())
620 }
621
622 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
623 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
624 if let Some(invite_code) = &user.invite_code {
625 let pool = self.connection_pool.lock();
626 for connection_id in pool.user_connection_ids(user_id) {
627 self.peer.send(
628 connection_id,
629 proto::UpdateInviteInfo {
630 url: format!(
631 "{}{}",
632 self.app_state.config.invite_link_prefix, invite_code
633 ),
634 count: user.invite_count as u32,
635 },
636 )?;
637 }
638 }
639 }
640 Ok(())
641 }
642
643 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
644 ServerSnapshot {
645 connection_pool: ConnectionPoolGuard {
646 guard: self.connection_pool.lock(),
647 _not_send: PhantomData,
648 },
649 peer: &self.peer,
650 }
651 }
652}
653
654impl<'a> Deref for ConnectionPoolGuard<'a> {
655 type Target = ConnectionPool;
656
657 fn deref(&self) -> &Self::Target {
658 &*self.guard
659 }
660}
661
662impl<'a> DerefMut for ConnectionPoolGuard<'a> {
663 fn deref_mut(&mut self) -> &mut Self::Target {
664 &mut *self.guard
665 }
666}
667
668impl<'a> Drop for ConnectionPoolGuard<'a> {
669 fn drop(&mut self) {
670 #[cfg(test)]
671 self.check_invariants();
672 }
673}
674
675fn broadcast<F>(
676 sender_id: ConnectionId,
677 receiver_ids: impl IntoIterator<Item = ConnectionId>,
678 mut f: F,
679) where
680 F: FnMut(ConnectionId) -> anyhow::Result<()>,
681{
682 for receiver_id in receiver_ids {
683 if receiver_id != sender_id {
684 f(receiver_id).trace_err();
685 }
686 }
687}
688
689lazy_static! {
690 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
691}
692
693pub struct ProtocolVersion(u32);
694
695impl Header for ProtocolVersion {
696 fn name() -> &'static HeaderName {
697 &ZED_PROTOCOL_VERSION
698 }
699
700 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
701 where
702 Self: Sized,
703 I: Iterator<Item = &'i axum::http::HeaderValue>,
704 {
705 let version = values
706 .next()
707 .ok_or_else(axum::headers::Error::invalid)?
708 .to_str()
709 .map_err(|_| axum::headers::Error::invalid())?
710 .parse()
711 .map_err(|_| axum::headers::Error::invalid())?;
712 Ok(Self(version))
713 }
714
715 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
716 values.extend([self.0.to_string().parse().unwrap()]);
717 }
718}
719
720pub fn routes(server: Arc<Server>) -> Router<Body> {
721 Router::new()
722 .route("/rpc", get(handle_websocket_request))
723 .layer(
724 ServiceBuilder::new()
725 .layer(Extension(server.app_state.clone()))
726 .layer(middleware::from_fn(auth::validate_header)),
727 )
728 .route("/metrics", get(handle_metrics))
729 .layer(Extension(server))
730}
731
732pub async fn handle_websocket_request(
733 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
734 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
735 Extension(server): Extension<Arc<Server>>,
736 Extension(user): Extension<User>,
737 ws: WebSocketUpgrade,
738) -> axum::response::Response {
739 if protocol_version != rpc::PROTOCOL_VERSION {
740 return (
741 StatusCode::UPGRADE_REQUIRED,
742 "client must be upgraded".to_string(),
743 )
744 .into_response();
745 }
746 let socket_address = socket_address.to_string();
747 ws.on_upgrade(move |socket| {
748 use util::ResultExt;
749 let socket = socket
750 .map_ok(to_tungstenite_message)
751 .err_into()
752 .with(|message| async move { Ok(to_axum_message(message)) });
753 let connection = Connection::new(Box::pin(socket));
754 async move {
755 server
756 .handle_connection(connection, socket_address, user, None, Executor::Production)
757 .await
758 .log_err();
759 }
760 })
761}
762
763pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
764 let connections = server
765 .connection_pool
766 .lock()
767 .connections()
768 .filter(|connection| !connection.admin)
769 .count();
770
771 METRIC_CONNECTIONS.set(connections as _);
772
773 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
774 METRIC_SHARED_PROJECTS.set(shared_projects as _);
775
776 let encoder = prometheus::TextEncoder::new();
777 let metric_families = prometheus::gather();
778 let encoded_metrics = encoder
779 .encode_to_string(&metric_families)
780 .map_err(|err| anyhow!("{}", err))?;
781 Ok(encoded_metrics)
782}
783
/// Cleans up after a connection has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool, and reports the
/// lost connection to the database, calling `project_left` for each project it reports.
/// It then waits `RECONNECT_TIMEOUT`. If the server is torn down first, nothing more
/// happens. Otherwise it:
/// - leaves the session's room;
/// - declines the user's call if none of their connections is still online;
/// - updates the user's contacts.
///
/// # Errors
/// Returns an error if `remove_connection` or `update_user_contacts` fails; other failures
/// are only logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Presumably gives the client a chance to reconnect before its room membership is
    // dropped. A teardown signal skips the remaining cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline the call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
829
830async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
831 response.send(proto::Ack {})?;
832 Ok(())
833}
834
835async fn create_room(
836 _request: proto::CreateRoom,
837 response: Response<proto::CreateRoom>,
838 session: Session,
839) -> Result<()> {
840 let live_kit_room = nanoid::nanoid!(30);
841 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
842 if let Some(_) = live_kit
843 .create_room(live_kit_room.clone())
844 .await
845 .trace_err()
846 {
847 if let Some(token) = live_kit
848 .room_token(&live_kit_room, &session.connection_id.to_string())
849 .trace_err()
850 {
851 Some(proto::LiveKitConnectionInfo {
852 server_url: live_kit.url().into(),
853 token,
854 })
855 } else {
856 None
857 }
858 } else {
859 None
860 }
861 } else {
862 None
863 };
864
865 {
866 let room = session
867 .db()
868 .await
869 .create_room(session.user_id, session.connection_id, &live_kit_room)
870 .await?;
871
872 response.send(proto::CreateRoomResponse {
873 room: Some(room.clone()),
874 live_kit_connection_info,
875 })?;
876 }
877
878 update_user_contacts(session.user_id, &session).await?;
879 Ok(())
880}
881
882async fn join_room(
883 request: proto::JoinRoom,
884 response: Response<proto::JoinRoom>,
885 session: Session,
886) -> Result<()> {
887 let room_id = RoomId::from_proto(request.id);
888 let room = {
889 let room = session
890 .db()
891 .await
892 .join_room(room_id, session.user_id, session.connection_id)
893 .await?;
894 room_updated(&room, &session.peer);
895 room.clone()
896 };
897
898 for connection_id in session
899 .connection_pool()
900 .await
901 .user_connection_ids(session.user_id)
902 {
903 session
904 .peer
905 .send(
906 connection_id,
907 proto::CallCanceled {
908 room_id: room_id.to_proto(),
909 },
910 )
911 .trace_err();
912 }
913
914 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
915 if let Some(token) = live_kit
916 .room_token(&room.live_kit_room, &session.connection_id.to_string())
917 .trace_err()
918 {
919 Some(proto::LiveKitConnectionInfo {
920 server_url: live_kit.url().into(),
921 token,
922 })
923 } else {
924 None
925 }
926 } else {
927 None
928 };
929
930 response.send(proto::JoinRoomResponse {
931 room: Some(room),
932 live_kit_connection_info,
933 })?;
934
935 update_user_contacts(session.user_id, &session).await?;
936 Ok(())
937}
938
939async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
940 leave_room_for_session(&session).await
941}
942
943async fn call(
944 request: proto::Call,
945 response: Response<proto::Call>,
946 session: Session,
947) -> Result<()> {
948 let room_id = RoomId::from_proto(request.room_id);
949 let calling_user_id = session.user_id;
950 let calling_connection_id = session.connection_id;
951 let called_user_id = UserId::from_proto(request.called_user_id);
952 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
953 if !session
954 .db()
955 .await
956 .has_contact(calling_user_id, called_user_id)
957 .await?
958 {
959 return Err(anyhow!("cannot call a user who isn't a contact"))?;
960 }
961
962 let incoming_call = {
963 let (room, incoming_call) = &mut *session
964 .db()
965 .await
966 .call(
967 room_id,
968 calling_user_id,
969 calling_connection_id,
970 called_user_id,
971 initial_project_id,
972 )
973 .await?;
974 room_updated(&room, &session.peer);
975 mem::take(incoming_call)
976 };
977 update_user_contacts(called_user_id, &session).await?;
978
979 let mut calls = session
980 .connection_pool()
981 .await
982 .user_connection_ids(called_user_id)
983 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
984 .collect::<FuturesUnordered<_>>();
985
986 while let Some(call_response) = calls.next().await {
987 match call_response.as_ref() {
988 Ok(_) => {
989 response.send(proto::Ack {})?;
990 return Ok(());
991 }
992 Err(_) => {
993 call_response.trace_err();
994 }
995 }
996 }
997
998 {
999 let room = session
1000 .db()
1001 .await
1002 .call_failed(room_id, called_user_id)
1003 .await?;
1004 room_updated(&room, &session.peer);
1005 }
1006 update_user_contacts(called_user_id, &session).await?;
1007
1008 Err(anyhow!("failed to ring user"))?
1009}
1010
1011async fn cancel_call(
1012 request: proto::CancelCall,
1013 response: Response<proto::CancelCall>,
1014 session: Session,
1015) -> Result<()> {
1016 let called_user_id = UserId::from_proto(request.called_user_id);
1017 let room_id = RoomId::from_proto(request.room_id);
1018 {
1019 let room = session
1020 .db()
1021 .await
1022 .cancel_call(Some(room_id), session.connection_id, called_user_id)
1023 .await?;
1024 room_updated(&room, &session.peer);
1025 }
1026
1027 for connection_id in session
1028 .connection_pool()
1029 .await
1030 .user_connection_ids(called_user_id)
1031 {
1032 session
1033 .peer
1034 .send(
1035 connection_id,
1036 proto::CallCanceled {
1037 room_id: room_id.to_proto(),
1038 },
1039 )
1040 .trace_err();
1041 }
1042 response.send(proto::Ack {})?;
1043
1044 update_user_contacts(called_user_id, &session).await?;
1045 Ok(())
1046}
1047
1048async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1049 let room_id = RoomId::from_proto(message.room_id);
1050 {
1051 let room = session
1052 .db()
1053 .await
1054 .decline_call(Some(room_id), session.user_id)
1055 .await?;
1056 room_updated(&room, &session.peer);
1057 }
1058
1059 for connection_id in session
1060 .connection_pool()
1061 .await
1062 .user_connection_ids(session.user_id)
1063 {
1064 session
1065 .peer
1066 .send(
1067 connection_id,
1068 proto::CallCanceled {
1069 room_id: room_id.to_proto(),
1070 },
1071 )
1072 .trace_err();
1073 }
1074 update_user_contacts(session.user_id, &session).await?;
1075 Ok(())
1076}
1077
1078async fn update_participant_location(
1079 request: proto::UpdateParticipantLocation,
1080 response: Response<proto::UpdateParticipantLocation>,
1081 session: Session,
1082) -> Result<()> {
1083 let room_id = RoomId::from_proto(request.room_id);
1084 let location = request
1085 .location
1086 .ok_or_else(|| anyhow!("invalid location"))?;
1087 let room = session
1088 .db()
1089 .await
1090 .update_room_participant_location(room_id, session.connection_id, location)
1091 .await?;
1092 room_updated(&room, &session.peer);
1093 response.send(proto::Ack {})?;
1094 Ok(())
1095}
1096
1097async fn share_project(
1098 request: proto::ShareProject,
1099 response: Response<proto::ShareProject>,
1100 session: Session,
1101) -> Result<()> {
1102 let (project_id, room) = &*session
1103 .db()
1104 .await
1105 .share_project(
1106 RoomId::from_proto(request.room_id),
1107 session.connection_id,
1108 &request.worktrees,
1109 )
1110 .await?;
1111 response.send(proto::ShareProjectResponse {
1112 project_id: project_id.to_proto(),
1113 })?;
1114 room_updated(&room, &session.peer);
1115
1116 Ok(())
1117}
1118
1119async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1120 let project_id = ProjectId::from_proto(message.project_id);
1121
1122 let (room, guest_connection_ids) = &*session
1123 .db()
1124 .await
1125 .unshare_project(project_id, session.connection_id)
1126 .await?;
1127
1128 broadcast(
1129 session.connection_id,
1130 guest_connection_ids.iter().copied(),
1131 |conn_id| session.peer.send(conn_id, message.clone()),
1132 );
1133 room_updated(&room, &session.peer);
1134
1135 Ok(())
1136}
1137
1138async fn join_project(
1139 request: proto::JoinProject,
1140 response: Response<proto::JoinProject>,
1141 session: Session,
1142) -> Result<()> {
1143 let project_id = ProjectId::from_proto(request.project_id);
1144 let guest_user_id = session.user_id;
1145
1146 tracing::info!(%project_id, "join project");
1147
1148 let (project, replica_id) = &mut *session
1149 .db()
1150 .await
1151 .join_project(project_id, session.connection_id)
1152 .await?;
1153
1154 let collaborators = project
1155 .collaborators
1156 .iter()
1157 .map(|collaborator| {
1158 let peer_id = proto::PeerId {
1159 epoch: collaborator.connection_server_id.0 as u32,
1160 id: collaborator.connection_id as u32,
1161 };
1162 proto::Collaborator {
1163 peer_id: Some(peer_id),
1164 replica_id: collaborator.replica_id.0 as u32,
1165 user_id: collaborator.user_id.to_proto(),
1166 }
1167 })
1168 .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1169 .collect::<Vec<_>>();
1170 let worktrees = project
1171 .worktrees
1172 .iter()
1173 .map(|(id, worktree)| proto::WorktreeMetadata {
1174 id: *id,
1175 root_name: worktree.root_name.clone(),
1176 visible: worktree.visible,
1177 abs_path: worktree.abs_path.clone(),
1178 })
1179 .collect::<Vec<_>>();
1180
1181 for collaborator in &collaborators {
1182 session
1183 .peer
1184 .send(
1185 collaborator.peer_id.unwrap().into(),
1186 proto::AddProjectCollaborator {
1187 project_id: project_id.to_proto(),
1188 collaborator: Some(proto::Collaborator {
1189 peer_id: Some(session.connection_id.into()),
1190 replica_id: replica_id.0 as u32,
1191 user_id: guest_user_id.to_proto(),
1192 }),
1193 },
1194 )
1195 .trace_err();
1196 }
1197
1198 // First, we send the metadata associated with each worktree.
1199 response.send(proto::JoinProjectResponse {
1200 worktrees: worktrees.clone(),
1201 replica_id: replica_id.0 as u32,
1202 collaborators: collaborators.clone(),
1203 language_servers: project.language_servers.clone(),
1204 })?;
1205
1206 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1207 #[cfg(any(test, feature = "test-support"))]
1208 const MAX_CHUNK_SIZE: usize = 2;
1209 #[cfg(not(any(test, feature = "test-support")))]
1210 const MAX_CHUNK_SIZE: usize = 256;
1211
1212 // Stream this worktree's entries.
1213 let message = proto::UpdateWorktree {
1214 project_id: project_id.to_proto(),
1215 worktree_id,
1216 abs_path: worktree.abs_path.clone(),
1217 root_name: worktree.root_name,
1218 updated_entries: worktree.entries,
1219 removed_entries: Default::default(),
1220 scan_id: worktree.scan_id,
1221 is_last_update: worktree.is_complete,
1222 };
1223 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1224 session.peer.send(session.connection_id, update.clone())?;
1225 }
1226
1227 // Stream this worktree's diagnostics.
1228 for summary in worktree.diagnostic_summaries {
1229 session.peer.send(
1230 session.connection_id,
1231 proto::UpdateDiagnosticSummary {
1232 project_id: project_id.to_proto(),
1233 worktree_id: worktree.id,
1234 summary: Some(summary),
1235 },
1236 )?;
1237 }
1238 }
1239
1240 for language_server in &project.language_servers {
1241 session.peer.send(
1242 session.connection_id,
1243 proto::UpdateLanguageServer {
1244 project_id: project_id.to_proto(),
1245 language_server_id: language_server.id,
1246 variant: Some(
1247 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1248 proto::LspDiskBasedDiagnosticsUpdated {},
1249 ),
1250 ),
1251 },
1252 )?;
1253 }
1254
1255 Ok(())
1256}
1257
1258async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1259 let sender_id = session.connection_id;
1260 let project_id = ProjectId::from_proto(request.project_id);
1261
1262 let project = session
1263 .db()
1264 .await
1265 .leave_project(project_id, sender_id)
1266 .await?;
1267 tracing::info!(
1268 %project_id,
1269 host_user_id = %project.host_user_id,
1270 host_connection_id = %project.host_connection_id,
1271 "leave project"
1272 );
1273 project_left(&project, &session);
1274
1275 Ok(())
1276}
1277
1278async fn update_project(
1279 request: proto::UpdateProject,
1280 response: Response<proto::UpdateProject>,
1281 session: Session,
1282) -> Result<()> {
1283 let project_id = ProjectId::from_proto(request.project_id);
1284 let (room, guest_connection_ids) = &*session
1285 .db()
1286 .await
1287 .update_project(project_id, session.connection_id, &request.worktrees)
1288 .await?;
1289 broadcast(
1290 session.connection_id,
1291 guest_connection_ids.iter().copied(),
1292 |connection_id| {
1293 session
1294 .peer
1295 .forward_send(session.connection_id, connection_id, request.clone())
1296 },
1297 );
1298 room_updated(&room, &session.peer);
1299 response.send(proto::Ack {})?;
1300
1301 Ok(())
1302}
1303
1304async fn update_worktree(
1305 request: proto::UpdateWorktree,
1306 response: Response<proto::UpdateWorktree>,
1307 session: Session,
1308) -> Result<()> {
1309 let guest_connection_ids = session
1310 .db()
1311 .await
1312 .update_worktree(&request, session.connection_id)
1313 .await?;
1314
1315 broadcast(
1316 session.connection_id,
1317 guest_connection_ids.iter().copied(),
1318 |connection_id| {
1319 session
1320 .peer
1321 .forward_send(session.connection_id, connection_id, request.clone())
1322 },
1323 );
1324 response.send(proto::Ack {})?;
1325 Ok(())
1326}
1327
1328async fn update_diagnostic_summary(
1329 message: proto::UpdateDiagnosticSummary,
1330 session: Session,
1331) -> Result<()> {
1332 let guest_connection_ids = session
1333 .db()
1334 .await
1335 .update_diagnostic_summary(&message, session.connection_id)
1336 .await?;
1337
1338 broadcast(
1339 session.connection_id,
1340 guest_connection_ids.iter().copied(),
1341 |connection_id| {
1342 session
1343 .peer
1344 .forward_send(session.connection_id, connection_id, message.clone())
1345 },
1346 );
1347
1348 Ok(())
1349}
1350
1351async fn start_language_server(
1352 request: proto::StartLanguageServer,
1353 session: Session,
1354) -> Result<()> {
1355 let guest_connection_ids = session
1356 .db()
1357 .await
1358 .start_language_server(&request, session.connection_id)
1359 .await?;
1360
1361 broadcast(
1362 session.connection_id,
1363 guest_connection_ids.iter().copied(),
1364 |connection_id| {
1365 session
1366 .peer
1367 .forward_send(session.connection_id, connection_id, request.clone())
1368 },
1369 );
1370 Ok(())
1371}
1372
1373async fn update_language_server(
1374 request: proto::UpdateLanguageServer,
1375 session: Session,
1376) -> Result<()> {
1377 let project_id = ProjectId::from_proto(request.project_id);
1378 let project_connection_ids = session
1379 .db()
1380 .await
1381 .project_connection_ids(project_id, session.connection_id)
1382 .await?;
1383 broadcast(
1384 session.connection_id,
1385 project_connection_ids.iter().copied(),
1386 |connection_id| {
1387 session
1388 .peer
1389 .forward_send(session.connection_id, connection_id, request.clone())
1390 },
1391 );
1392 Ok(())
1393}
1394
1395async fn forward_project_request<T>(
1396 request: T,
1397 response: Response<T>,
1398 session: Session,
1399) -> Result<()>
1400where
1401 T: EntityMessage + RequestMessage,
1402{
1403 let project_id = ProjectId::from_proto(request.remote_entity_id());
1404 let host_connection_id = {
1405 let collaborators = session
1406 .db()
1407 .await
1408 .project_collaborators(project_id, session.connection_id)
1409 .await?;
1410 let host = collaborators
1411 .iter()
1412 .find(|collaborator| collaborator.is_host)
1413 .ok_or_else(|| anyhow!("host not found"))?;
1414 ConnectionId {
1415 epoch: host.connection_server_id.0 as u32,
1416 id: host.connection_id as u32,
1417 }
1418 };
1419
1420 let payload = session
1421 .peer
1422 .forward_request(session.connection_id, host_connection_id, request)
1423 .await?;
1424
1425 response.send(payload)?;
1426 Ok(())
1427}
1428
1429async fn save_buffer(
1430 request: proto::SaveBuffer,
1431 response: Response<proto::SaveBuffer>,
1432 session: Session,
1433) -> Result<()> {
1434 let project_id = ProjectId::from_proto(request.project_id);
1435 let host_connection_id = {
1436 let collaborators = session
1437 .db()
1438 .await
1439 .project_collaborators(project_id, session.connection_id)
1440 .await?;
1441 let host = collaborators
1442 .iter()
1443 .find(|collaborator| collaborator.is_host)
1444 .ok_or_else(|| anyhow!("host not found"))?;
1445 ConnectionId {
1446 epoch: host.connection_server_id.0 as u32,
1447 id: host.connection_id as u32,
1448 }
1449 };
1450 let response_payload = session
1451 .peer
1452 .forward_request(session.connection_id, host_connection_id, request.clone())
1453 .await?;
1454
1455 let mut collaborators = session
1456 .db()
1457 .await
1458 .project_collaborators(project_id, session.connection_id)
1459 .await?;
1460 collaborators.retain(|collaborator| {
1461 let collaborator_connection = ConnectionId {
1462 epoch: collaborator.connection_server_id.0 as u32,
1463 id: collaborator.connection_id as u32,
1464 };
1465 collaborator_connection != session.connection_id
1466 });
1467 let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
1468 epoch: collaborator.connection_server_id.0 as u32,
1469 id: collaborator.connection_id as u32,
1470 });
1471 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1472 session
1473 .peer
1474 .forward_send(host_connection_id, conn_id, response_payload.clone())
1475 });
1476 response.send(response_payload)?;
1477 Ok(())
1478}
1479
1480async fn create_buffer_for_peer(
1481 request: proto::CreateBufferForPeer,
1482 session: Session,
1483) -> Result<()> {
1484 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1485 session
1486 .peer
1487 .forward_send(session.connection_id, peer_id.into(), request)?;
1488 Ok(())
1489}
1490
1491async fn update_buffer(
1492 request: proto::UpdateBuffer,
1493 response: Response<proto::UpdateBuffer>,
1494 session: Session,
1495) -> Result<()> {
1496 let project_id = ProjectId::from_proto(request.project_id);
1497 let project_connection_ids = session
1498 .db()
1499 .await
1500 .project_connection_ids(project_id, session.connection_id)
1501 .await?;
1502
1503 dbg!(session.connection_id, &*project_connection_ids);
1504 broadcast(
1505 session.connection_id,
1506 project_connection_ids.iter().copied(),
1507 |connection_id| {
1508 session
1509 .peer
1510 .forward_send(session.connection_id, connection_id, request.clone())
1511 },
1512 );
1513 response.send(proto::Ack {})?;
1514 Ok(())
1515}
1516
1517async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1518 let project_id = ProjectId::from_proto(request.project_id);
1519 let project_connection_ids = session
1520 .db()
1521 .await
1522 .project_connection_ids(project_id, session.connection_id)
1523 .await?;
1524
1525 broadcast(
1526 session.connection_id,
1527 project_connection_ids.iter().copied(),
1528 |connection_id| {
1529 session
1530 .peer
1531 .forward_send(session.connection_id, connection_id, request.clone())
1532 },
1533 );
1534 Ok(())
1535}
1536
1537async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1538 let project_id = ProjectId::from_proto(request.project_id);
1539 let project_connection_ids = session
1540 .db()
1541 .await
1542 .project_connection_ids(project_id, session.connection_id)
1543 .await?;
1544 broadcast(
1545 session.connection_id,
1546 project_connection_ids.iter().copied(),
1547 |connection_id| {
1548 session
1549 .peer
1550 .forward_send(session.connection_id, connection_id, request.clone())
1551 },
1552 );
1553 Ok(())
1554}
1555
1556async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1557 let project_id = ProjectId::from_proto(request.project_id);
1558 let project_connection_ids = session
1559 .db()
1560 .await
1561 .project_connection_ids(project_id, session.connection_id)
1562 .await?;
1563 broadcast(
1564 session.connection_id,
1565 project_connection_ids.iter().copied(),
1566 |connection_id| {
1567 session
1568 .peer
1569 .forward_send(session.connection_id, connection_id, request.clone())
1570 },
1571 );
1572 Ok(())
1573}
1574
1575async fn follow(
1576 request: proto::Follow,
1577 response: Response<proto::Follow>,
1578 session: Session,
1579) -> Result<()> {
1580 let project_id = ProjectId::from_proto(request.project_id);
1581 let leader_id = request
1582 .leader_id
1583 .ok_or_else(|| anyhow!("invalid leader id"))?
1584 .into();
1585 let follower_id = session.connection_id;
1586 {
1587 let project_connection_ids = session
1588 .db()
1589 .await
1590 .project_connection_ids(project_id, session.connection_id)
1591 .await?;
1592
1593 if !project_connection_ids.contains(&leader_id) {
1594 Err(anyhow!("no such peer"))?;
1595 }
1596 }
1597
1598 let mut response_payload = session
1599 .peer
1600 .forward_request(session.connection_id, leader_id, request)
1601 .await?;
1602 response_payload
1603 .views
1604 .retain(|view| view.leader_id != Some(follower_id.into()));
1605 response.send(response_payload)?;
1606 Ok(())
1607}
1608
1609async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1610 let project_id = ProjectId::from_proto(request.project_id);
1611 let leader_id = request
1612 .leader_id
1613 .ok_or_else(|| anyhow!("invalid leader id"))?
1614 .into();
1615 let project_connection_ids = session
1616 .db()
1617 .await
1618 .project_connection_ids(project_id, session.connection_id)
1619 .await?;
1620 if !project_connection_ids.contains(&leader_id) {
1621 Err(anyhow!("no such peer"))?;
1622 }
1623 session
1624 .peer
1625 .forward_send(session.connection_id, leader_id, request)?;
1626 Ok(())
1627}
1628
1629async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1630 let project_id = ProjectId::from_proto(request.project_id);
1631 let project_connection_ids = session
1632 .db
1633 .lock()
1634 .await
1635 .project_connection_ids(project_id, session.connection_id)
1636 .await?;
1637
1638 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1639 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1640 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1641 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1642 });
1643 for follower_peer_id in request.follower_ids.iter().copied() {
1644 let follower_connection_id = follower_peer_id.into();
1645 if project_connection_ids.contains(&follower_connection_id)
1646 && Some(follower_peer_id) != leader_id
1647 {
1648 session.peer.forward_send(
1649 session.connection_id,
1650 follower_connection_id,
1651 request.clone(),
1652 )?;
1653 }
1654 }
1655 Ok(())
1656}
1657
1658async fn get_users(
1659 request: proto::GetUsers,
1660 response: Response<proto::GetUsers>,
1661 session: Session,
1662) -> Result<()> {
1663 let user_ids = request
1664 .user_ids
1665 .into_iter()
1666 .map(UserId::from_proto)
1667 .collect();
1668 let users = session
1669 .db()
1670 .await
1671 .get_users_by_ids(user_ids)
1672 .await?
1673 .into_iter()
1674 .map(|user| proto::User {
1675 id: user.id.to_proto(),
1676 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1677 github_login: user.github_login,
1678 })
1679 .collect();
1680 response.send(proto::UsersResponse { users })?;
1681 Ok(())
1682}
1683
1684async fn fuzzy_search_users(
1685 request: proto::FuzzySearchUsers,
1686 response: Response<proto::FuzzySearchUsers>,
1687 session: Session,
1688) -> Result<()> {
1689 let query = request.query;
1690 let users = match query.len() {
1691 0 => vec![],
1692 1 | 2 => session
1693 .db()
1694 .await
1695 .get_user_by_github_account(&query, None)
1696 .await?
1697 .into_iter()
1698 .collect(),
1699 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1700 };
1701 let users = users
1702 .into_iter()
1703 .filter(|user| user.id != session.user_id)
1704 .map(|user| proto::User {
1705 id: user.id.to_proto(),
1706 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1707 github_login: user.github_login,
1708 })
1709 .collect();
1710 response.send(proto::UsersResponse { users })?;
1711 Ok(())
1712}
1713
1714async fn request_contact(
1715 request: proto::RequestContact,
1716 response: Response<proto::RequestContact>,
1717 session: Session,
1718) -> Result<()> {
1719 let requester_id = session.user_id;
1720 let responder_id = UserId::from_proto(request.responder_id);
1721 if requester_id == responder_id {
1722 return Err(anyhow!("cannot add yourself as a contact"))?;
1723 }
1724
1725 session
1726 .db()
1727 .await
1728 .send_contact_request(requester_id, responder_id)
1729 .await?;
1730
1731 // Update outgoing contact requests of requester
1732 let mut update = proto::UpdateContacts::default();
1733 update.outgoing_requests.push(responder_id.to_proto());
1734 for connection_id in session
1735 .connection_pool()
1736 .await
1737 .user_connection_ids(requester_id)
1738 {
1739 session.peer.send(connection_id, update.clone())?;
1740 }
1741
1742 // Update incoming contact requests of responder
1743 let mut update = proto::UpdateContacts::default();
1744 update
1745 .incoming_requests
1746 .push(proto::IncomingContactRequest {
1747 requester_id: requester_id.to_proto(),
1748 should_notify: true,
1749 });
1750 for connection_id in session
1751 .connection_pool()
1752 .await
1753 .user_connection_ids(responder_id)
1754 {
1755 session.peer.send(connection_id, update.clone())?;
1756 }
1757
1758 response.send(proto::Ack {})?;
1759 Ok(())
1760}
1761
1762async fn respond_to_contact_request(
1763 request: proto::RespondToContactRequest,
1764 response: Response<proto::RespondToContactRequest>,
1765 session: Session,
1766) -> Result<()> {
1767 let responder_id = session.user_id;
1768 let requester_id = UserId::from_proto(request.requester_id);
1769 let db = session.db().await;
1770 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1771 db.dismiss_contact_notification(responder_id, requester_id)
1772 .await?;
1773 } else {
1774 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1775
1776 db.respond_to_contact_request(responder_id, requester_id, accept)
1777 .await?;
1778 let requester_busy = db.is_user_busy(requester_id).await?;
1779 let responder_busy = db.is_user_busy(responder_id).await?;
1780
1781 let pool = session.connection_pool().await;
1782 // Update responder with new contact
1783 let mut update = proto::UpdateContacts::default();
1784 if accept {
1785 update
1786 .contacts
1787 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1788 }
1789 update
1790 .remove_incoming_requests
1791 .push(requester_id.to_proto());
1792 for connection_id in pool.user_connection_ids(responder_id) {
1793 session.peer.send(connection_id, update.clone())?;
1794 }
1795
1796 // Update requester with new contact
1797 let mut update = proto::UpdateContacts::default();
1798 if accept {
1799 update
1800 .contacts
1801 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1802 }
1803 update
1804 .remove_outgoing_requests
1805 .push(responder_id.to_proto());
1806 for connection_id in pool.user_connection_ids(requester_id) {
1807 session.peer.send(connection_id, update.clone())?;
1808 }
1809 }
1810
1811 response.send(proto::Ack {})?;
1812 Ok(())
1813}
1814
1815async fn remove_contact(
1816 request: proto::RemoveContact,
1817 response: Response<proto::RemoveContact>,
1818 session: Session,
1819) -> Result<()> {
1820 let requester_id = session.user_id;
1821 let responder_id = UserId::from_proto(request.user_id);
1822 let db = session.db().await;
1823 db.remove_contact(requester_id, responder_id).await?;
1824
1825 let pool = session.connection_pool().await;
1826 // Update outgoing contact requests of requester
1827 let mut update = proto::UpdateContacts::default();
1828 update
1829 .remove_outgoing_requests
1830 .push(responder_id.to_proto());
1831 for connection_id in pool.user_connection_ids(requester_id) {
1832 session.peer.send(connection_id, update.clone())?;
1833 }
1834
1835 // Update incoming contact requests of responder
1836 let mut update = proto::UpdateContacts::default();
1837 update
1838 .remove_incoming_requests
1839 .push(requester_id.to_proto());
1840 for connection_id in pool.user_connection_ids(responder_id) {
1841 session.peer.send(connection_id, update.clone())?;
1842 }
1843
1844 response.send(proto::Ack {})?;
1845 Ok(())
1846}
1847
1848async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1849 let project_id = ProjectId::from_proto(request.project_id);
1850 let project_connection_ids = session
1851 .db()
1852 .await
1853 .project_connection_ids(project_id, session.connection_id)
1854 .await?;
1855 broadcast(
1856 session.connection_id,
1857 project_connection_ids.iter().copied(),
1858 |connection_id| {
1859 session
1860 .peer
1861 .forward_send(session.connection_id, connection_id, request.clone())
1862 },
1863 );
1864 Ok(())
1865}
1866
1867async fn get_private_user_info(
1868 _request: proto::GetPrivateUserInfo,
1869 response: Response<proto::GetPrivateUserInfo>,
1870 session: Session,
1871) -> Result<()> {
1872 let metrics_id = session
1873 .db()
1874 .await
1875 .get_user_metrics_id(session.user_id)
1876 .await?;
1877 let user = session
1878 .db()
1879 .await
1880 .get_user_by_id(session.user_id)
1881 .await?
1882 .ok_or_else(|| anyhow!("user not found"))?;
1883 response.send(proto::GetPrivateUserInfoResponse {
1884 metrics_id,
1885 staff: user.admin,
1886 })?;
1887 Ok(())
1888}
1889
1890fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1891 match message {
1892 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1893 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1894 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1895 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1896 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1897 code: frame.code.into(),
1898 reason: frame.reason,
1899 })),
1900 }
1901}
1902
1903fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1904 match message {
1905 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1906 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1907 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1908 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1909 AxumMessage::Close(frame) => {
1910 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1911 code: frame.code.into(),
1912 reason: frame.reason,
1913 }))
1914 }
1915 }
1916}
1917
1918fn build_initial_contacts_update(
1919 contacts: Vec<db::Contact>,
1920 pool: &ConnectionPool,
1921) -> proto::UpdateContacts {
1922 let mut update = proto::UpdateContacts::default();
1923
1924 for contact in contacts {
1925 match contact {
1926 db::Contact::Accepted {
1927 user_id,
1928 should_notify,
1929 busy,
1930 } => {
1931 update
1932 .contacts
1933 .push(contact_for_user(user_id, should_notify, busy, &pool));
1934 }
1935 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1936 db::Contact::Incoming {
1937 user_id,
1938 should_notify,
1939 } => update
1940 .incoming_requests
1941 .push(proto::IncomingContactRequest {
1942 requester_id: user_id.to_proto(),
1943 should_notify,
1944 }),
1945 }
1946 }
1947
1948 update
1949}
1950
1951fn contact_for_user(
1952 user_id: UserId,
1953 should_notify: bool,
1954 busy: bool,
1955 pool: &ConnectionPool,
1956) -> proto::Contact {
1957 proto::Contact {
1958 user_id: user_id.to_proto(),
1959 online: pool.is_user_online(user_id),
1960 busy,
1961 should_notify,
1962 }
1963}
1964
1965fn room_updated(room: &proto::Room, peer: &Peer) {
1966 for participant in &room.participants {
1967 if let Some(peer_id) = participant
1968 .peer_id
1969 .ok_or_else(|| anyhow!("invalid participant peer id"))
1970 .trace_err()
1971 {
1972 peer.send(
1973 peer_id.into(),
1974 proto::RoomUpdated {
1975 room: Some(room.clone()),
1976 },
1977 )
1978 .trace_err();
1979 }
1980 }
1981}
1982
1983async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1984 let db = session.db().await;
1985 let contacts = db.get_contacts(user_id).await?;
1986 let busy = db.is_user_busy(user_id).await?;
1987
1988 let pool = session.connection_pool().await;
1989 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1990 for contact in contacts {
1991 if let db::Contact::Accepted {
1992 user_id: contact_user_id,
1993 ..
1994 } = contact
1995 {
1996 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1997 session
1998 .peer
1999 .send(
2000 contact_conn_id,
2001 proto::UpdateContacts {
2002 contacts: vec![updated_contact.clone()],
2003 remove_contacts: Default::default(),
2004 incoming_requests: Default::default(),
2005 remove_incoming_requests: Default::default(),
2006 outgoing_requests: Default::default(),
2007 remove_outgoing_requests: Default::default(),
2008 },
2009 )
2010 .trace_err();
2011 }
2012 }
2013 }
2014 Ok(())
2015}
2016
/// Removes the session's connection from its current room and performs the
/// follow-up notifications.
///
/// Steps, in order:
/// 1. Leave the room in the database and call `project_left` for every
///    project the connection left.
/// 2. Broadcast the updated room.
/// 3. Send `CallCanceled` to every connection of users whose pending calls
///    were canceled.
/// 4. Refresh contacts for the session user and those users.
/// 5. If a LiveKit client is configured, remove this connection from the
///    LiveKit room, and delete the room when no participants remain.
///
/// NOTE(review): an error from `update_user_contacts` returns early via `?`,
/// which skips the LiveKit cleanup below — confirm this is intended.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values pulled out of the database result so that result can be dropped
    // at the end of the block below (presumably to release it before the
    // connection pool is locked — confirm).
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // The LiveKit room is deleted later only if nobody is left in the room.
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    // Scoped so the pool lock is released before `update_user_contacts` runs.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2074
2075fn project_left(project: &db::LeftProject, session: &Session) {
2076 for connection_id in &project.connection_ids {
2077 if project.host_user_id == session.user_id {
2078 session
2079 .peer
2080 .send(
2081 *connection_id,
2082 proto::UnshareProject {
2083 project_id: project.id.to_proto(),
2084 },
2085 )
2086 .trace_err();
2087 } else {
2088 session
2089 .peer
2090 .send(
2091 *connection_id,
2092 proto::RemoveProjectCollaborator {
2093 project_id: project.id.to_proto(),
2094 peer_id: Some(session.connection_id.into()),
2095 },
2096 )
2097 .trace_err();
2098 }
2099 }
2100
2101 session
2102 .peer
2103 .send(
2104 session.connection_id,
2105 proto::UnshareProject {
2106 project_id: project.id.to_proto(),
2107 },
2108 )
2109 .trace_err();
2110}
2111
2112pub trait ResultExt {
2113 type Ok;
2114
2115 fn trace_err(self) -> Option<Self::Ok>;
2116}
2117
2118impl<T, E> ResultExt for Result<T, E>
2119where
2120 E: std::fmt::Debug,
2121{
2122 type Ok = T;
2123
2124 fn trace_err(self) -> Option<T> {
2125 match self {
2126 Ok(value) => Some(value),
2127 Err(error) => {
2128 tracing::error!("{:?}", error);
2129 None
2130 }
2131 }
2132 }
2133}