1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// Grace period a dropped connection gets before `connection_lost` makes the
/// user leave their room, declines pending calls and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay after `Server::start` before rooms the database reports as stale are
/// refreshed and stale servers are deleted.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
156#[derive(Serialize)]
157pub struct ServerSnapshot<'a> {
158 peer: &'a Peer,
159 #[serde(serialize_with = "serialize_deref")]
160 connection_pool: ConnectionPoolGuard<'a>,
161}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(save_buffer)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 pub async fn start(&self) -> Result<()> {
247 let server_id = *self.id.lock();
248 let app_state = self.app_state.clone();
249 let peer = self.peer.clone();
250 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
251 let pool = self.connection_pool.clone();
252 let live_kit_client = self.app_state.live_kit_client.clone();
253
254 let span = info_span!("start server");
255 self.executor.spawn_detached(
256 async move {
257 tracing::info!("waiting for cleanup timeout");
258 timeout.await;
259 tracing::info!("cleanup timeout expired, retrieving stale rooms");
260 if let Some(room_ids) = app_state
261 .db
262 .stale_room_ids(&app_state.config.zed_environment, server_id)
263 .await
264 .trace_err()
265 {
266 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
267 for room_id in room_ids {
268 let mut contacts_to_update = HashSet::default();
269 let mut canceled_calls_to_user_ids = Vec::new();
270 let mut live_kit_room = String::new();
271 let mut delete_live_kit_room = false;
272
273 if let Some(mut refreshed_room) = app_state
274 .db
275 .refresh_room(room_id, server_id)
276 .await
277 .trace_err()
278 {
279 tracing::info!(
280 room_id = room_id.0,
281 new_participant_count = refreshed_room.room.participants.len(),
282 "refreshed room"
283 );
284 room_updated(&refreshed_room.room, &peer);
285 contacts_to_update
286 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
287 contacts_to_update
288 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
289 canceled_calls_to_user_ids =
290 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
291 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
292 delete_live_kit_room = refreshed_room.room.participants.is_empty();
293 }
294
295 {
296 let pool = pool.lock();
297 for canceled_user_id in canceled_calls_to_user_ids {
298 for connection_id in pool.user_connection_ids(canceled_user_id) {
299 peer.send(
300 connection_id,
301 proto::CallCanceled {
302 room_id: room_id.to_proto(),
303 },
304 )
305 .trace_err();
306 }
307 }
308 }
309
310 for user_id in contacts_to_update {
311 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
312 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
313 if let Some((busy, contacts)) = busy.zip(contacts) {
314 let pool = pool.lock();
315 let updated_contact = contact_for_user(user_id, false, busy, &pool);
316 for contact in contacts {
317 if let db::Contact::Accepted {
318 user_id: contact_user_id,
319 ..
320 } = contact
321 {
322 for contact_conn_id in
323 pool.user_connection_ids(contact_user_id)
324 {
325 peer.send(
326 contact_conn_id,
327 proto::UpdateContacts {
328 contacts: vec![updated_contact.clone()],
329 remove_contacts: Default::default(),
330 incoming_requests: Default::default(),
331 remove_incoming_requests: Default::default(),
332 outgoing_requests: Default::default(),
333 remove_outgoing_requests: Default::default(),
334 },
335 )
336 .trace_err();
337 }
338 }
339 }
340 }
341 }
342
343 if let Some(live_kit) = live_kit_client.as_ref() {
344 if delete_live_kit_room {
345 live_kit.delete_room(live_kit_room).await.trace_err();
346 }
347 }
348 }
349 }
350
351 app_state
352 .db
353 .delete_stale_servers(&app_state.config.zed_environment, server_id)
354 .await
355 .trace_err();
356 }
357 .instrument(span),
358 );
359 Ok(())
360 }
361
362 pub fn teardown(&self) {
363 self.peer.teardown();
364 self.connection_pool.lock().reset();
365 let _ = self.teardown.send(());
366 }
367
368 #[cfg(test)]
369 pub fn reset(&self, id: ServerId) {
370 self.teardown();
371 *self.id.lock() = id;
372 self.peer.reset(id.0 as u32);
373 }
374
375 #[cfg(test)]
376 pub fn id(&self) -> ServerId {
377 *self.id.lock()
378 }
379
380 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
381 where
382 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
383 Fut: 'static + Send + Future<Output = Result<()>>,
384 M: EnvelopedMessage,
385 {
386 let prev_handler = self.handlers.insert(
387 TypeId::of::<M>(),
388 Box::new(move |envelope, session| {
389 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
390 let span = info_span!(
391 "handle message",
392 payload_type = envelope.payload_type_name()
393 );
394 span.in_scope(|| {
395 tracing::info!(
396 payload_type = envelope.payload_type_name(),
397 "message received"
398 );
399 });
400 let future = (handler)(*envelope, session);
401 async move {
402 if let Err(error) = future.await {
403 tracing::error!(%error, "error handling message");
404 }
405 }
406 .instrument(span)
407 .boxed()
408 }),
409 );
410 if prev_handler.is_some() {
411 panic!("registered a handler for the same message twice");
412 }
413 self
414 }
415
416 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
417 where
418 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
419 Fut: 'static + Send + Future<Output = Result<()>>,
420 M: EnvelopedMessage,
421 {
422 self.add_handler(move |envelope, session| handler(envelope.payload, session));
423 self
424 }
425
426 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
427 where
428 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
429 Fut: Send + Future<Output = Result<()>>,
430 M: RequestMessage,
431 {
432 let handler = Arc::new(handler);
433 self.add_handler(move |envelope, session| {
434 let receipt = envelope.receipt();
435 let handler = handler.clone();
436 async move {
437 let peer = session.peer.clone();
438 let responded = Arc::new(AtomicBool::default());
439 let response = Response {
440 peer: peer.clone(),
441 responded: responded.clone(),
442 receipt,
443 };
444 match (handler)(envelope.payload, response, session).await {
445 Ok(()) => {
446 if responded.load(std::sync::atomic::Ordering::SeqCst) {
447 Ok(())
448 } else {
449 Err(anyhow!("handler did not send a response"))?
450 }
451 }
452 Err(error) => {
453 peer.respond_with_error(
454 receipt,
455 proto::Error {
456 message: error.to_string(),
457 },
458 )?;
459 Err(error)
460 }
461 }
462 }
463 })
464 }
465
466 pub fn handle_connection(
467 self: &Arc<Self>,
468 connection: Connection,
469 address: String,
470 user: User,
471 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
472 executor: Executor,
473 ) -> impl Future<Output = Result<()>> {
474 let this = self.clone();
475 let user_id = user.id;
476 let login = user.github_login;
477 let span = info_span!("handle connection", %user_id, %login, %address);
478 let mut teardown = self.teardown.subscribe();
479 async move {
480 let (connection_id, handle_io, mut incoming_rx) = this
481 .peer
482 .add_connection(connection, {
483 let executor = executor.clone();
484 move |duration| executor.sleep(duration)
485 });
486
487 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
488 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
489 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
490
491 if let Some(send_connection_id) = send_connection_id.take() {
492 let _ = send_connection_id.send(connection_id);
493 }
494
495 if !user.connected_once {
496 this.peer.send(connection_id, proto::ShowContacts {})?;
497 this.app_state.db.set_user_connected_once(user_id, true).await?;
498 }
499
500 let (contacts, invite_code) = future::try_join(
501 this.app_state.db.get_contacts(user_id),
502 this.app_state.db.get_invite_code_for_user(user_id)
503 ).await?;
504
505 {
506 let mut pool = this.connection_pool.lock();
507 pool.add_connection(connection_id, user_id, user.admin);
508 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
509
510 if let Some((code, count)) = invite_code {
511 this.peer.send(connection_id, proto::UpdateInviteInfo {
512 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
513 count: count as u32,
514 })?;
515 }
516 }
517
518 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
519 this.peer.send(connection_id, incoming_call)?;
520 }
521
522 let session = Session {
523 user_id,
524 connection_id,
525 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
526 peer: this.peer.clone(),
527 connection_pool: this.connection_pool.clone(),
528 live_kit_client: this.app_state.live_kit_client.clone(),
529 executor: executor.clone(),
530 };
531 update_user_contacts(user_id, &session).await?;
532
533 let handle_io = handle_io.fuse();
534 futures::pin_mut!(handle_io);
535
536 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
537 // This prevents deadlocks when e.g., client A performs a request to client B and
538 // client B performs a request to client A. If both clients stop processing further
539 // messages until their respective request completes, they won't have a chance to
540 // respond to the other client's request and cause a deadlock.
541 //
542 // This arrangement ensures we will attempt to process earlier messages first, but fall
543 // back to processing messages arrived later in the spirit of making progress.
544 let mut foreground_message_handlers = FuturesUnordered::new();
545 loop {
546 let next_message = incoming_rx.next().fuse();
547 futures::pin_mut!(next_message);
548 futures::select_biased! {
549 _ = teardown.changed().fuse() => return Ok(()),
550 result = handle_io => {
551 if let Err(error) = result {
552 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
553 }
554 break;
555 }
556 _ = foreground_message_handlers.next() => {}
557 message = next_message => {
558 if let Some(message) = message {
559 let type_name = message.payload_type_name();
560 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
561 let span_enter = span.enter();
562 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
563 let is_background = message.is_background();
564 let handle_message = (handler)(message, session.clone());
565 drop(span_enter);
566
567 let handle_message = handle_message.instrument(span);
568 if is_background {
569 executor.spawn_detached(handle_message);
570 } else {
571 foreground_message_handlers.push(handle_message);
572 }
573 } else {
574 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
575 }
576 } else {
577 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
578 break;
579 }
580 }
581 }
582 }
583
584 drop(foreground_message_handlers);
585 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
586 if let Err(error) = connection_lost(session, teardown, executor).await {
587 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
588 }
589
590 Ok(())
591 }.instrument(span)
592 }
593
594 pub async fn invite_code_redeemed(
595 self: &Arc<Self>,
596 inviter_id: UserId,
597 invitee_id: UserId,
598 ) -> Result<()> {
599 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
600 if let Some(code) = &user.invite_code {
601 let pool = self.connection_pool.lock();
602 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
603 for connection_id in pool.user_connection_ids(inviter_id) {
604 self.peer.send(
605 connection_id,
606 proto::UpdateContacts {
607 contacts: vec![invitee_contact.clone()],
608 ..Default::default()
609 },
610 )?;
611 self.peer.send(
612 connection_id,
613 proto::UpdateInviteInfo {
614 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
615 count: user.invite_count as u32,
616 },
617 )?;
618 }
619 }
620 }
621 Ok(())
622 }
623
624 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
625 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
626 if let Some(invite_code) = &user.invite_code {
627 let pool = self.connection_pool.lock();
628 for connection_id in pool.user_connection_ids(user_id) {
629 self.peer.send(
630 connection_id,
631 proto::UpdateInviteInfo {
632 url: format!(
633 "{}{}",
634 self.app_state.config.invite_link_prefix, invite_code
635 ),
636 count: user.invite_count as u32,
637 },
638 )?;
639 }
640 }
641 }
642 Ok(())
643 }
644
645 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
646 ServerSnapshot {
647 connection_pool: ConnectionPoolGuard {
648 guard: self.connection_pool.lock(),
649 _not_send: PhantomData,
650 },
651 peer: &self.peer,
652 }
653 }
654}
655
656impl<'a> Deref for ConnectionPoolGuard<'a> {
657 type Target = ConnectionPool;
658
659 fn deref(&self) -> &Self::Target {
660 &*self.guard
661 }
662}
663
664impl<'a> DerefMut for ConnectionPoolGuard<'a> {
665 fn deref_mut(&mut self) -> &mut Self::Target {
666 &mut *self.guard
667 }
668}
669
670impl<'a> Drop for ConnectionPoolGuard<'a> {
671 fn drop(&mut self) {
672 #[cfg(test)]
673 self.check_invariants();
674 }
675}
676
677fn broadcast<F>(
678 sender_id: Option<ConnectionId>,
679 receiver_ids: impl IntoIterator<Item = ConnectionId>,
680 mut f: F,
681) where
682 F: FnMut(ConnectionId) -> anyhow::Result<()>,
683{
684 for receiver_id in receiver_ids {
685 if Some(receiver_id) != sender_id {
686 if let Err(error) = f(receiver_id) {
687 tracing::error!("failed to send to {:?} {}", receiver_id, error);
688 }
689 }
690 }
691}
692
693lazy_static! {
694 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
695}
696
697pub struct ProtocolVersion(u32);
698
699impl Header for ProtocolVersion {
700 fn name() -> &'static HeaderName {
701 &ZED_PROTOCOL_VERSION
702 }
703
704 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
705 where
706 Self: Sized,
707 I: Iterator<Item = &'i axum::http::HeaderValue>,
708 {
709 let version = values
710 .next()
711 .ok_or_else(axum::headers::Error::invalid)?
712 .to_str()
713 .map_err(|_| axum::headers::Error::invalid())?
714 .parse()
715 .map_err(|_| axum::headers::Error::invalid())?;
716 Ok(Self(version))
717 }
718
719 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
720 values.extend([self.0.to_string().parse().unwrap()]);
721 }
722}
723
724pub fn routes(server: Arc<Server>) -> Router<Body> {
725 Router::new()
726 .route("/rpc", get(handle_websocket_request))
727 .layer(
728 ServiceBuilder::new()
729 .layer(Extension(server.app_state.clone()))
730 .layer(middleware::from_fn(auth::validate_header)),
731 )
732 .route("/metrics", get(handle_metrics))
733 .layer(Extension(server))
734}
735
736pub async fn handle_websocket_request(
737 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
738 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
739 Extension(server): Extension<Arc<Server>>,
740 Extension(user): Extension<User>,
741 ws: WebSocketUpgrade,
742) -> axum::response::Response {
743 if protocol_version != rpc::PROTOCOL_VERSION {
744 return (
745 StatusCode::UPGRADE_REQUIRED,
746 "client must be upgraded".to_string(),
747 )
748 .into_response();
749 }
750 let socket_address = socket_address.to_string();
751 ws.on_upgrade(move |socket| {
752 use util::ResultExt;
753 let socket = socket
754 .map_ok(to_tungstenite_message)
755 .err_into()
756 .with(|message| async move { Ok(to_axum_message(message)) });
757 let connection = Connection::new(Box::pin(socket));
758 async move {
759 server
760 .handle_connection(connection, socket_address, user, None, Executor::Production)
761 .await
762 .log_err();
763 }
764 })
765}
766
767pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
768 let connections = server
769 .connection_pool
770 .lock()
771 .connections()
772 .filter(|connection| !connection.admin)
773 .count();
774
775 METRIC_CONNECTIONS.set(connections as _);
776
777 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
778 METRIC_SHARED_PROJECTS.set(shared_projects as _);
779
780 let encoder = prometheus::TextEncoder::new();
781 let metric_families = prometheus::gather();
782 let encoded_metrics = encoder
783 .encode_to_string(&metric_families)
784 .map_err(|err| anyhow!("{}", err))?;
785 Ok(encoded_metrics)
786}
787
788#[instrument(err, skip(executor))]
789async fn connection_lost(
790 session: Session,
791 mut teardown: watch::Receiver<()>,
792 executor: Executor,
793) -> Result<()> {
794 session.peer.disconnect(session.connection_id);
795 session
796 .connection_pool()
797 .await
798 .remove_connection(session.connection_id)?;
799
800 session
801 .db()
802 .await
803 .connection_lost(session.connection_id)
804 .await
805 .trace_err();
806
807 futures::select_biased! {
808 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
809 leave_room_for_session(&session).await.trace_err();
810
811 if !session
812 .connection_pool()
813 .await
814 .is_user_online(session.user_id)
815 {
816 let db = session.db().await;
817 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
818 room_updated(&room, &session.peer);
819 }
820 }
821 update_user_contacts(session.user_id, &session).await?;
822 }
823 _ = teardown.changed().fuse() => {}
824 }
825
826 Ok(())
827}
828
829async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
830 response.send(proto::Ack {})?;
831 Ok(())
832}
833
834async fn create_room(
835 _request: proto::CreateRoom,
836 response: Response<proto::CreateRoom>,
837 session: Session,
838) -> Result<()> {
839 let live_kit_room = nanoid::nanoid!(30);
840 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
841 if let Some(_) = live_kit
842 .create_room(live_kit_room.clone())
843 .await
844 .trace_err()
845 {
846 if let Some(token) = live_kit
847 .room_token(&live_kit_room, &session.user_id.to_string())
848 .trace_err()
849 {
850 Some(proto::LiveKitConnectionInfo {
851 server_url: live_kit.url().into(),
852 token,
853 })
854 } else {
855 None
856 }
857 } else {
858 None
859 }
860 } else {
861 None
862 };
863
864 {
865 let room = session
866 .db()
867 .await
868 .create_room(session.user_id, session.connection_id, &live_kit_room)
869 .await?;
870
871 response.send(proto::CreateRoomResponse {
872 room: Some(room.clone()),
873 live_kit_connection_info,
874 })?;
875 }
876
877 update_user_contacts(session.user_id, &session).await?;
878 Ok(())
879}
880
881async fn join_room(
882 request: proto::JoinRoom,
883 response: Response<proto::JoinRoom>,
884 session: Session,
885) -> Result<()> {
886 let room_id = RoomId::from_proto(request.id);
887 let room = {
888 let room = session
889 .db()
890 .await
891 .join_room(room_id, session.user_id, session.connection_id)
892 .await?;
893 room_updated(&room, &session.peer);
894 room.clone()
895 };
896
897 for connection_id in session
898 .connection_pool()
899 .await
900 .user_connection_ids(session.user_id)
901 {
902 session
903 .peer
904 .send(
905 connection_id,
906 proto::CallCanceled {
907 room_id: room_id.to_proto(),
908 },
909 )
910 .trace_err();
911 }
912
913 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
914 if let Some(token) = live_kit
915 .room_token(&room.live_kit_room, &session.user_id.to_string())
916 .trace_err()
917 {
918 Some(proto::LiveKitConnectionInfo {
919 server_url: live_kit.url().into(),
920 token,
921 })
922 } else {
923 None
924 }
925 } else {
926 None
927 };
928
929 response.send(proto::JoinRoomResponse {
930 room: Some(room),
931 live_kit_connection_info,
932 })?;
933
934 update_user_contacts(session.user_id, &session).await?;
935 Ok(())
936}
937
938async fn rejoin_room(
939 request: proto::RejoinRoom,
940 response: Response<proto::RejoinRoom>,
941 session: Session,
942) -> Result<()> {
943 {
944 let mut rejoined_room = session
945 .db()
946 .await
947 .rejoin_room(request, session.user_id, session.connection_id)
948 .await?;
949
950 response.send(proto::RejoinRoomResponse {
951 room: Some(rejoined_room.room.clone()),
952 reshared_projects: rejoined_room
953 .reshared_projects
954 .iter()
955 .map(|project| proto::ResharedProject {
956 id: project.id.to_proto(),
957 collaborators: project
958 .collaborators
959 .iter()
960 .map(|collaborator| collaborator.to_proto())
961 .collect(),
962 })
963 .collect(),
964 rejoined_projects: rejoined_room
965 .rejoined_projects
966 .iter()
967 .map(|rejoined_project| proto::RejoinedProject {
968 id: rejoined_project.id.to_proto(),
969 worktrees: rejoined_project
970 .worktrees
971 .iter()
972 .map(|worktree| proto::WorktreeMetadata {
973 id: worktree.id,
974 root_name: worktree.root_name.clone(),
975 visible: worktree.visible,
976 abs_path: worktree.abs_path.clone(),
977 })
978 .collect(),
979 collaborators: rejoined_project
980 .collaborators
981 .iter()
982 .map(|collaborator| collaborator.to_proto())
983 .collect(),
984 language_servers: rejoined_project.language_servers.clone(),
985 })
986 .collect(),
987 })?;
988 room_updated(&rejoined_room.room, &session.peer);
989
990 for project in &rejoined_room.reshared_projects {
991 for collaborator in &project.collaborators {
992 session
993 .peer
994 .send(
995 collaborator.connection_id,
996 proto::UpdateProjectCollaborator {
997 project_id: project.id.to_proto(),
998 old_peer_id: Some(project.old_connection_id.into()),
999 new_peer_id: Some(session.connection_id.into()),
1000 },
1001 )
1002 .trace_err();
1003 }
1004
1005 broadcast(
1006 Some(session.connection_id),
1007 project
1008 .collaborators
1009 .iter()
1010 .map(|collaborator| collaborator.connection_id),
1011 |connection_id| {
1012 session.peer.forward_send(
1013 session.connection_id,
1014 connection_id,
1015 proto::UpdateProject {
1016 project_id: project.id.to_proto(),
1017 worktrees: project.worktrees.clone(),
1018 },
1019 )
1020 },
1021 );
1022 }
1023
1024 for project in &rejoined_room.rejoined_projects {
1025 for collaborator in &project.collaborators {
1026 session
1027 .peer
1028 .send(
1029 collaborator.connection_id,
1030 proto::UpdateProjectCollaborator {
1031 project_id: project.id.to_proto(),
1032 old_peer_id: Some(project.old_connection_id.into()),
1033 new_peer_id: Some(session.connection_id.into()),
1034 },
1035 )
1036 .trace_err();
1037 }
1038 }
1039
1040 for project in &mut rejoined_room.rejoined_projects {
1041 for worktree in mem::take(&mut project.worktrees) {
1042 #[cfg(any(test, feature = "test-support"))]
1043 const MAX_CHUNK_SIZE: usize = 2;
1044 #[cfg(not(any(test, feature = "test-support")))]
1045 const MAX_CHUNK_SIZE: usize = 256;
1046
1047 // Stream this worktree's entries.
1048 let message = proto::UpdateWorktree {
1049 project_id: project.id.to_proto(),
1050 worktree_id: worktree.id,
1051 abs_path: worktree.abs_path.clone(),
1052 root_name: worktree.root_name,
1053 updated_entries: worktree.updated_entries,
1054 removed_entries: worktree.removed_entries,
1055 scan_id: worktree.scan_id,
1056 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1057 };
1058 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1059 session.peer.send(session.connection_id, update.clone())?;
1060 }
1061
1062 // Stream this worktree's diagnostics.
1063 for summary in worktree.diagnostic_summaries {
1064 session.peer.send(
1065 session.connection_id,
1066 proto::UpdateDiagnosticSummary {
1067 project_id: project.id.to_proto(),
1068 worktree_id: worktree.id,
1069 summary: Some(summary),
1070 },
1071 )?;
1072 }
1073 }
1074
1075 for language_server in &project.language_servers {
1076 session.peer.send(
1077 session.connection_id,
1078 proto::UpdateLanguageServer {
1079 project_id: project.id.to_proto(),
1080 language_server_id: language_server.id,
1081 variant: Some(
1082 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1083 proto::LspDiskBasedDiagnosticsUpdated {},
1084 ),
1085 ),
1086 },
1087 )?;
1088 }
1089 }
1090 }
1091
1092 update_user_contacts(session.user_id, &session).await?;
1093 Ok(())
1094}
1095
1096async fn leave_room(
1097 _: proto::LeaveRoom,
1098 response: Response<proto::LeaveRoom>,
1099 session: Session,
1100) -> Result<()> {
1101 leave_room_for_session(&session).await?;
1102 response.send(proto::Ack {})?;
1103 Ok(())
1104}
1105
1106async fn call(
1107 request: proto::Call,
1108 response: Response<proto::Call>,
1109 session: Session,
1110) -> Result<()> {
1111 let room_id = RoomId::from_proto(request.room_id);
1112 let calling_user_id = session.user_id;
1113 let calling_connection_id = session.connection_id;
1114 let called_user_id = UserId::from_proto(request.called_user_id);
1115 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1116 if !session
1117 .db()
1118 .await
1119 .has_contact(calling_user_id, called_user_id)
1120 .await?
1121 {
1122 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1123 }
1124
1125 let incoming_call = {
1126 let (room, incoming_call) = &mut *session
1127 .db()
1128 .await
1129 .call(
1130 room_id,
1131 calling_user_id,
1132 calling_connection_id,
1133 called_user_id,
1134 initial_project_id,
1135 )
1136 .await?;
1137 room_updated(&room, &session.peer);
1138 mem::take(incoming_call)
1139 };
1140 update_user_contacts(called_user_id, &session).await?;
1141
1142 let mut calls = session
1143 .connection_pool()
1144 .await
1145 .user_connection_ids(called_user_id)
1146 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1147 .collect::<FuturesUnordered<_>>();
1148
1149 while let Some(call_response) = calls.next().await {
1150 match call_response.as_ref() {
1151 Ok(_) => {
1152 response.send(proto::Ack {})?;
1153 return Ok(());
1154 }
1155 Err(_) => {
1156 call_response.trace_err();
1157 }
1158 }
1159 }
1160
1161 {
1162 let room = session
1163 .db()
1164 .await
1165 .call_failed(room_id, called_user_id)
1166 .await?;
1167 room_updated(&room, &session.peer);
1168 }
1169 update_user_contacts(called_user_id, &session).await?;
1170
1171 Err(anyhow!("failed to ring user"))?
1172}
1173
1174async fn cancel_call(
1175 request: proto::CancelCall,
1176 response: Response<proto::CancelCall>,
1177 session: Session,
1178) -> Result<()> {
1179 let called_user_id = UserId::from_proto(request.called_user_id);
1180 let room_id = RoomId::from_proto(request.room_id);
1181 {
1182 let room = session
1183 .db()
1184 .await
1185 .cancel_call(room_id, session.connection_id, called_user_id)
1186 .await?;
1187 room_updated(&room, &session.peer);
1188 }
1189
1190 for connection_id in session
1191 .connection_pool()
1192 .await
1193 .user_connection_ids(called_user_id)
1194 {
1195 session
1196 .peer
1197 .send(
1198 connection_id,
1199 proto::CallCanceled {
1200 room_id: room_id.to_proto(),
1201 },
1202 )
1203 .trace_err();
1204 }
1205 response.send(proto::Ack {})?;
1206
1207 update_user_contacts(called_user_id, &session).await?;
1208 Ok(())
1209}
1210
1211async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1212 let room_id = RoomId::from_proto(message.room_id);
1213 {
1214 let room = session
1215 .db()
1216 .await
1217 .decline_call(Some(room_id), session.user_id)
1218 .await?
1219 .ok_or_else(|| anyhow!("failed to decline call"))?;
1220 room_updated(&room, &session.peer);
1221 }
1222
1223 for connection_id in session
1224 .connection_pool()
1225 .await
1226 .user_connection_ids(session.user_id)
1227 {
1228 session
1229 .peer
1230 .send(
1231 connection_id,
1232 proto::CallCanceled {
1233 room_id: room_id.to_proto(),
1234 },
1235 )
1236 .trace_err();
1237 }
1238 update_user_contacts(session.user_id, &session).await?;
1239 Ok(())
1240}
1241
1242async fn update_participant_location(
1243 request: proto::UpdateParticipantLocation,
1244 response: Response<proto::UpdateParticipantLocation>,
1245 session: Session,
1246) -> Result<()> {
1247 let room_id = RoomId::from_proto(request.room_id);
1248 let location = request
1249 .location
1250 .ok_or_else(|| anyhow!("invalid location"))?;
1251 let room = session
1252 .db()
1253 .await
1254 .update_room_participant_location(room_id, session.connection_id, location)
1255 .await?;
1256 room_updated(&room, &session.peer);
1257 response.send(proto::Ack {})?;
1258 Ok(())
1259}
1260
1261async fn share_project(
1262 request: proto::ShareProject,
1263 response: Response<proto::ShareProject>,
1264 session: Session,
1265) -> Result<()> {
1266 let (project_id, room) = &*session
1267 .db()
1268 .await
1269 .share_project(
1270 RoomId::from_proto(request.room_id),
1271 session.connection_id,
1272 &request.worktrees,
1273 )
1274 .await?;
1275 response.send(proto::ShareProjectResponse {
1276 project_id: project_id.to_proto(),
1277 })?;
1278 room_updated(&room, &session.peer);
1279
1280 Ok(())
1281}
1282
1283async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1284 let project_id = ProjectId::from_proto(message.project_id);
1285
1286 let (room, guest_connection_ids) = &*session
1287 .db()
1288 .await
1289 .unshare_project(project_id, session.connection_id)
1290 .await?;
1291
1292 broadcast(
1293 Some(session.connection_id),
1294 guest_connection_ids.iter().copied(),
1295 |conn_id| session.peer.send(conn_id, message.clone()),
1296 );
1297 room_updated(&room, &session.peer);
1298
1299 Ok(())
1300}
1301
1302async fn join_project(
1303 request: proto::JoinProject,
1304 response: Response<proto::JoinProject>,
1305 session: Session,
1306) -> Result<()> {
1307 let project_id = ProjectId::from_proto(request.project_id);
1308 let guest_user_id = session.user_id;
1309
1310 tracing::info!(%project_id, "join project");
1311
1312 let (project, replica_id) = &mut *session
1313 .db()
1314 .await
1315 .join_project(project_id, session.connection_id)
1316 .await?;
1317
1318 let collaborators = project
1319 .collaborators
1320 .iter()
1321 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1322 .map(|collaborator| collaborator.to_proto())
1323 .collect::<Vec<_>>();
1324
1325 let worktrees = project
1326 .worktrees
1327 .iter()
1328 .map(|(id, worktree)| proto::WorktreeMetadata {
1329 id: *id,
1330 root_name: worktree.root_name.clone(),
1331 visible: worktree.visible,
1332 abs_path: worktree.abs_path.clone(),
1333 })
1334 .collect::<Vec<_>>();
1335
1336 for collaborator in &collaborators {
1337 session
1338 .peer
1339 .send(
1340 collaborator.peer_id.unwrap().into(),
1341 proto::AddProjectCollaborator {
1342 project_id: project_id.to_proto(),
1343 collaborator: Some(proto::Collaborator {
1344 peer_id: Some(session.connection_id.into()),
1345 replica_id: replica_id.0 as u32,
1346 user_id: guest_user_id.to_proto(),
1347 }),
1348 },
1349 )
1350 .trace_err();
1351 }
1352
1353 // First, we send the metadata associated with each worktree.
1354 response.send(proto::JoinProjectResponse {
1355 worktrees: worktrees.clone(),
1356 replica_id: replica_id.0 as u32,
1357 collaborators: collaborators.clone(),
1358 language_servers: project.language_servers.clone(),
1359 })?;
1360
1361 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1362 #[cfg(any(test, feature = "test-support"))]
1363 const MAX_CHUNK_SIZE: usize = 2;
1364 #[cfg(not(any(test, feature = "test-support")))]
1365 const MAX_CHUNK_SIZE: usize = 256;
1366
1367 // Stream this worktree's entries.
1368 let message = proto::UpdateWorktree {
1369 project_id: project_id.to_proto(),
1370 worktree_id,
1371 abs_path: worktree.abs_path.clone(),
1372 root_name: worktree.root_name,
1373 updated_entries: worktree.entries,
1374 removed_entries: Default::default(),
1375 scan_id: worktree.scan_id,
1376 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1377 };
1378 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1379 session.peer.send(session.connection_id, update.clone())?;
1380 }
1381
1382 // Stream this worktree's diagnostics.
1383 for summary in worktree.diagnostic_summaries {
1384 session.peer.send(
1385 session.connection_id,
1386 proto::UpdateDiagnosticSummary {
1387 project_id: project_id.to_proto(),
1388 worktree_id: worktree.id,
1389 summary: Some(summary),
1390 },
1391 )?;
1392 }
1393 }
1394
1395 for language_server in &project.language_servers {
1396 session.peer.send(
1397 session.connection_id,
1398 proto::UpdateLanguageServer {
1399 project_id: project_id.to_proto(),
1400 language_server_id: language_server.id,
1401 variant: Some(
1402 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1403 proto::LspDiskBasedDiagnosticsUpdated {},
1404 ),
1405 ),
1406 },
1407 )?;
1408 }
1409
1410 Ok(())
1411}
1412
1413async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1414 let sender_id = session.connection_id;
1415 let project_id = ProjectId::from_proto(request.project_id);
1416
1417 let (room, project) = &*session
1418 .db()
1419 .await
1420 .leave_project(project_id, sender_id)
1421 .await?;
1422 tracing::info!(
1423 %project_id,
1424 host_user_id = %project.host_user_id,
1425 host_connection_id = %project.host_connection_id,
1426 "leave project"
1427 );
1428
1429 project_left(&project, &session);
1430 room_updated(&room, &session.peer);
1431
1432 Ok(())
1433}
1434
1435async fn update_project(
1436 request: proto::UpdateProject,
1437 response: Response<proto::UpdateProject>,
1438 session: Session,
1439) -> Result<()> {
1440 let project_id = ProjectId::from_proto(request.project_id);
1441 let (room, guest_connection_ids) = &*session
1442 .db()
1443 .await
1444 .update_project(project_id, session.connection_id, &request.worktrees)
1445 .await?;
1446 broadcast(
1447 Some(session.connection_id),
1448 guest_connection_ids.iter().copied(),
1449 |connection_id| {
1450 session
1451 .peer
1452 .forward_send(session.connection_id, connection_id, request.clone())
1453 },
1454 );
1455 room_updated(&room, &session.peer);
1456 response.send(proto::Ack {})?;
1457
1458 Ok(())
1459}
1460
1461async fn update_worktree(
1462 request: proto::UpdateWorktree,
1463 response: Response<proto::UpdateWorktree>,
1464 session: Session,
1465) -> Result<()> {
1466 let guest_connection_ids = session
1467 .db()
1468 .await
1469 .update_worktree(&request, session.connection_id)
1470 .await?;
1471
1472 broadcast(
1473 Some(session.connection_id),
1474 guest_connection_ids.iter().copied(),
1475 |connection_id| {
1476 session
1477 .peer
1478 .forward_send(session.connection_id, connection_id, request.clone())
1479 },
1480 );
1481 response.send(proto::Ack {})?;
1482 Ok(())
1483}
1484
1485async fn update_diagnostic_summary(
1486 message: proto::UpdateDiagnosticSummary,
1487 session: Session,
1488) -> Result<()> {
1489 let guest_connection_ids = session
1490 .db()
1491 .await
1492 .update_diagnostic_summary(&message, session.connection_id)
1493 .await?;
1494
1495 broadcast(
1496 Some(session.connection_id),
1497 guest_connection_ids.iter().copied(),
1498 |connection_id| {
1499 session
1500 .peer
1501 .forward_send(session.connection_id, connection_id, message.clone())
1502 },
1503 );
1504
1505 Ok(())
1506}
1507
1508async fn start_language_server(
1509 request: proto::StartLanguageServer,
1510 session: Session,
1511) -> Result<()> {
1512 let guest_connection_ids = session
1513 .db()
1514 .await
1515 .start_language_server(&request, session.connection_id)
1516 .await?;
1517
1518 broadcast(
1519 Some(session.connection_id),
1520 guest_connection_ids.iter().copied(),
1521 |connection_id| {
1522 session
1523 .peer
1524 .forward_send(session.connection_id, connection_id, request.clone())
1525 },
1526 );
1527 Ok(())
1528}
1529
1530async fn update_language_server(
1531 request: proto::UpdateLanguageServer,
1532 session: Session,
1533) -> Result<()> {
1534 session.executor.record_backtrace();
1535 let project_id = ProjectId::from_proto(request.project_id);
1536 let project_connection_ids = session
1537 .db()
1538 .await
1539 .project_connection_ids(project_id, session.connection_id)
1540 .await?;
1541 broadcast(
1542 Some(session.connection_id),
1543 project_connection_ids.iter().copied(),
1544 |connection_id| {
1545 session
1546 .peer
1547 .forward_send(session.connection_id, connection_id, request.clone())
1548 },
1549 );
1550 Ok(())
1551}
1552
1553async fn forward_project_request<T>(
1554 request: T,
1555 response: Response<T>,
1556 session: Session,
1557) -> Result<()>
1558where
1559 T: EntityMessage + RequestMessage,
1560{
1561 session.executor.record_backtrace();
1562 let project_id = ProjectId::from_proto(request.remote_entity_id());
1563 let host_connection_id = {
1564 let collaborators = session
1565 .db()
1566 .await
1567 .project_collaborators(project_id, session.connection_id)
1568 .await?;
1569 collaborators
1570 .iter()
1571 .find(|collaborator| collaborator.is_host)
1572 .ok_or_else(|| anyhow!("host not found"))?
1573 .connection_id
1574 };
1575
1576 let payload = session
1577 .peer
1578 .forward_request(session.connection_id, host_connection_id, request)
1579 .await?;
1580
1581 response.send(payload)?;
1582 Ok(())
1583}
1584
1585async fn save_buffer(
1586 request: proto::SaveBuffer,
1587 response: Response<proto::SaveBuffer>,
1588 session: Session,
1589) -> Result<()> {
1590 let project_id = ProjectId::from_proto(request.project_id);
1591 let host_connection_id = {
1592 let collaborators = session
1593 .db()
1594 .await
1595 .project_collaborators(project_id, session.connection_id)
1596 .await?;
1597 collaborators
1598 .iter()
1599 .find(|collaborator| collaborator.is_host)
1600 .ok_or_else(|| anyhow!("host not found"))?
1601 .connection_id
1602 };
1603 let response_payload = session
1604 .peer
1605 .forward_request(session.connection_id, host_connection_id, request.clone())
1606 .await?;
1607
1608 let mut collaborators = session
1609 .db()
1610 .await
1611 .project_collaborators(project_id, session.connection_id)
1612 .await?;
1613 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1614 let project_connection_ids = collaborators
1615 .iter()
1616 .map(|collaborator| collaborator.connection_id);
1617 broadcast(
1618 Some(host_connection_id),
1619 project_connection_ids,
1620 |conn_id| {
1621 session
1622 .peer
1623 .forward_send(host_connection_id, conn_id, response_payload.clone())
1624 },
1625 );
1626 response.send(response_payload)?;
1627 Ok(())
1628}
1629
1630async fn create_buffer_for_peer(
1631 request: proto::CreateBufferForPeer,
1632 session: Session,
1633) -> Result<()> {
1634 session.executor.record_backtrace();
1635 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1636 session
1637 .peer
1638 .forward_send(session.connection_id, peer_id.into(), request)?;
1639 Ok(())
1640}
1641
1642async fn update_buffer(
1643 request: proto::UpdateBuffer,
1644 response: Response<proto::UpdateBuffer>,
1645 session: Session,
1646) -> Result<()> {
1647 session.executor.record_backtrace();
1648 let project_id = ProjectId::from_proto(request.project_id);
1649 let project_connection_ids = session
1650 .db()
1651 .await
1652 .project_connection_ids(project_id, session.connection_id)
1653 .await?;
1654
1655 session.executor.record_backtrace();
1656
1657 broadcast(
1658 Some(session.connection_id),
1659 project_connection_ids.iter().copied(),
1660 |connection_id| {
1661 session
1662 .peer
1663 .forward_send(session.connection_id, connection_id, request.clone())
1664 },
1665 );
1666 response.send(proto::Ack {})?;
1667 Ok(())
1668}
1669
1670async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1671 let project_id = ProjectId::from_proto(request.project_id);
1672 let project_connection_ids = session
1673 .db()
1674 .await
1675 .project_connection_ids(project_id, session.connection_id)
1676 .await?;
1677
1678 broadcast(
1679 Some(session.connection_id),
1680 project_connection_ids.iter().copied(),
1681 |connection_id| {
1682 session
1683 .peer
1684 .forward_send(session.connection_id, connection_id, request.clone())
1685 },
1686 );
1687 Ok(())
1688}
1689
1690async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1691 let project_id = ProjectId::from_proto(request.project_id);
1692 let project_connection_ids = session
1693 .db()
1694 .await
1695 .project_connection_ids(project_id, session.connection_id)
1696 .await?;
1697 broadcast(
1698 Some(session.connection_id),
1699 project_connection_ids.iter().copied(),
1700 |connection_id| {
1701 session
1702 .peer
1703 .forward_send(session.connection_id, connection_id, request.clone())
1704 },
1705 );
1706 Ok(())
1707}
1708
1709async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1710 let project_id = ProjectId::from_proto(request.project_id);
1711 let project_connection_ids = session
1712 .db()
1713 .await
1714 .project_connection_ids(project_id, session.connection_id)
1715 .await?;
1716 broadcast(
1717 Some(session.connection_id),
1718 project_connection_ids.iter().copied(),
1719 |connection_id| {
1720 session
1721 .peer
1722 .forward_send(session.connection_id, connection_id, request.clone())
1723 },
1724 );
1725 Ok(())
1726}
1727
1728async fn follow(
1729 request: proto::Follow,
1730 response: Response<proto::Follow>,
1731 session: Session,
1732) -> Result<()> {
1733 let project_id = ProjectId::from_proto(request.project_id);
1734 let leader_id = request
1735 .leader_id
1736 .ok_or_else(|| anyhow!("invalid leader id"))?
1737 .into();
1738 let follower_id = session.connection_id;
1739
1740 {
1741 let project_connection_ids = session
1742 .db()
1743 .await
1744 .project_connection_ids(project_id, session.connection_id)
1745 .await?;
1746
1747 if !project_connection_ids.contains(&leader_id) {
1748 Err(anyhow!("no such peer"))?;
1749 }
1750 }
1751
1752 let mut response_payload = session
1753 .peer
1754 .forward_request(session.connection_id, leader_id, request)
1755 .await?;
1756 response_payload
1757 .views
1758 .retain(|view| view.leader_id != Some(follower_id.into()));
1759 response.send(response_payload)?;
1760
1761 let room = session
1762 .db()
1763 .await
1764 .follow(project_id, leader_id, follower_id)
1765 .await?;
1766 room_updated(&room, &session.peer);
1767
1768 Ok(())
1769}
1770
1771async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1772 let project_id = ProjectId::from_proto(request.project_id);
1773 let leader_id = request
1774 .leader_id
1775 .ok_or_else(|| anyhow!("invalid leader id"))?
1776 .into();
1777 let follower_id = session.connection_id;
1778
1779 if !session
1780 .db()
1781 .await
1782 .project_connection_ids(project_id, session.connection_id)
1783 .await?
1784 .contains(&leader_id)
1785 {
1786 Err(anyhow!("no such peer"))?;
1787 }
1788
1789 session
1790 .peer
1791 .forward_send(session.connection_id, leader_id, request)?;
1792
1793 let room = session
1794 .db()
1795 .await
1796 .unfollow(project_id, leader_id, follower_id)
1797 .await?;
1798 room_updated(&room, &session.peer);
1799
1800 Ok(())
1801}
1802
1803async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1804 let project_id = ProjectId::from_proto(request.project_id);
1805 let project_connection_ids = session
1806 .db
1807 .lock()
1808 .await
1809 .project_connection_ids(project_id, session.connection_id)
1810 .await?;
1811
1812 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1813 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1814 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1815 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1816 });
1817 for follower_peer_id in request.follower_ids.iter().copied() {
1818 let follower_connection_id = follower_peer_id.into();
1819 if project_connection_ids.contains(&follower_connection_id)
1820 && Some(follower_peer_id) != leader_id
1821 {
1822 session.peer.forward_send(
1823 session.connection_id,
1824 follower_connection_id,
1825 request.clone(),
1826 )?;
1827 }
1828 }
1829 Ok(())
1830}
1831
1832async fn get_users(
1833 request: proto::GetUsers,
1834 response: Response<proto::GetUsers>,
1835 session: Session,
1836) -> Result<()> {
1837 let user_ids = request
1838 .user_ids
1839 .into_iter()
1840 .map(UserId::from_proto)
1841 .collect();
1842 let users = session
1843 .db()
1844 .await
1845 .get_users_by_ids(user_ids)
1846 .await?
1847 .into_iter()
1848 .map(|user| proto::User {
1849 id: user.id.to_proto(),
1850 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1851 github_login: user.github_login,
1852 })
1853 .collect();
1854 response.send(proto::UsersResponse { users })?;
1855 Ok(())
1856}
1857
1858async fn fuzzy_search_users(
1859 request: proto::FuzzySearchUsers,
1860 response: Response<proto::FuzzySearchUsers>,
1861 session: Session,
1862) -> Result<()> {
1863 let query = request.query;
1864 let users = match query.len() {
1865 0 => vec![],
1866 1 | 2 => session
1867 .db()
1868 .await
1869 .get_user_by_github_account(&query, None)
1870 .await?
1871 .into_iter()
1872 .collect(),
1873 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1874 };
1875 let users = users
1876 .into_iter()
1877 .filter(|user| user.id != session.user_id)
1878 .map(|user| proto::User {
1879 id: user.id.to_proto(),
1880 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1881 github_login: user.github_login,
1882 })
1883 .collect();
1884 response.send(proto::UsersResponse { users })?;
1885 Ok(())
1886}
1887
1888async fn request_contact(
1889 request: proto::RequestContact,
1890 response: Response<proto::RequestContact>,
1891 session: Session,
1892) -> Result<()> {
1893 let requester_id = session.user_id;
1894 let responder_id = UserId::from_proto(request.responder_id);
1895 if requester_id == responder_id {
1896 return Err(anyhow!("cannot add yourself as a contact"))?;
1897 }
1898
1899 session
1900 .db()
1901 .await
1902 .send_contact_request(requester_id, responder_id)
1903 .await?;
1904
1905 // Update outgoing contact requests of requester
1906 let mut update = proto::UpdateContacts::default();
1907 update.outgoing_requests.push(responder_id.to_proto());
1908 for connection_id in session
1909 .connection_pool()
1910 .await
1911 .user_connection_ids(requester_id)
1912 {
1913 session.peer.send(connection_id, update.clone())?;
1914 }
1915
1916 // Update incoming contact requests of responder
1917 let mut update = proto::UpdateContacts::default();
1918 update
1919 .incoming_requests
1920 .push(proto::IncomingContactRequest {
1921 requester_id: requester_id.to_proto(),
1922 should_notify: true,
1923 });
1924 for connection_id in session
1925 .connection_pool()
1926 .await
1927 .user_connection_ids(responder_id)
1928 {
1929 session.peer.send(connection_id, update.clone())?;
1930 }
1931
1932 response.send(proto::Ack {})?;
1933 Ok(())
1934}
1935
1936async fn respond_to_contact_request(
1937 request: proto::RespondToContactRequest,
1938 response: Response<proto::RespondToContactRequest>,
1939 session: Session,
1940) -> Result<()> {
1941 let responder_id = session.user_id;
1942 let requester_id = UserId::from_proto(request.requester_id);
1943 let db = session.db().await;
1944 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1945 db.dismiss_contact_notification(responder_id, requester_id)
1946 .await?;
1947 } else {
1948 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1949
1950 db.respond_to_contact_request(responder_id, requester_id, accept)
1951 .await?;
1952 let requester_busy = db.is_user_busy(requester_id).await?;
1953 let responder_busy = db.is_user_busy(responder_id).await?;
1954
1955 let pool = session.connection_pool().await;
1956 // Update responder with new contact
1957 let mut update = proto::UpdateContacts::default();
1958 if accept {
1959 update
1960 .contacts
1961 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1962 }
1963 update
1964 .remove_incoming_requests
1965 .push(requester_id.to_proto());
1966 for connection_id in pool.user_connection_ids(responder_id) {
1967 session.peer.send(connection_id, update.clone())?;
1968 }
1969
1970 // Update requester with new contact
1971 let mut update = proto::UpdateContacts::default();
1972 if accept {
1973 update
1974 .contacts
1975 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1976 }
1977 update
1978 .remove_outgoing_requests
1979 .push(responder_id.to_proto());
1980 for connection_id in pool.user_connection_ids(requester_id) {
1981 session.peer.send(connection_id, update.clone())?;
1982 }
1983 }
1984
1985 response.send(proto::Ack {})?;
1986 Ok(())
1987}
1988
1989async fn remove_contact(
1990 request: proto::RemoveContact,
1991 response: Response<proto::RemoveContact>,
1992 session: Session,
1993) -> Result<()> {
1994 let requester_id = session.user_id;
1995 let responder_id = UserId::from_proto(request.user_id);
1996 let db = session.db().await;
1997 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
1998
1999 let pool = session.connection_pool().await;
2000 // Update outgoing contact requests of requester
2001 let mut update = proto::UpdateContacts::default();
2002 if contact_accepted {
2003 update.remove_contacts.push(responder_id.to_proto());
2004 } else {
2005 update
2006 .remove_outgoing_requests
2007 .push(responder_id.to_proto());
2008 }
2009 for connection_id in pool.user_connection_ids(requester_id) {
2010 session.peer.send(connection_id, update.clone())?;
2011 }
2012
2013 // Update incoming contact requests of responder
2014 let mut update = proto::UpdateContacts::default();
2015 if contact_accepted {
2016 update.remove_contacts.push(requester_id.to_proto());
2017 } else {
2018 update
2019 .remove_incoming_requests
2020 .push(requester_id.to_proto());
2021 }
2022 for connection_id in pool.user_connection_ids(responder_id) {
2023 session.peer.send(connection_id, update.clone())?;
2024 }
2025
2026 response.send(proto::Ack {})?;
2027 Ok(())
2028}
2029
2030async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2031 let project_id = ProjectId::from_proto(request.project_id);
2032 let project_connection_ids = session
2033 .db()
2034 .await
2035 .project_connection_ids(project_id, session.connection_id)
2036 .await?;
2037 broadcast(
2038 Some(session.connection_id),
2039 project_connection_ids.iter().copied(),
2040 |connection_id| {
2041 session
2042 .peer
2043 .forward_send(session.connection_id, connection_id, request.clone())
2044 },
2045 );
2046 Ok(())
2047}
2048
2049async fn get_private_user_info(
2050 _request: proto::GetPrivateUserInfo,
2051 response: Response<proto::GetPrivateUserInfo>,
2052 session: Session,
2053) -> Result<()> {
2054 let metrics_id = session
2055 .db()
2056 .await
2057 .get_user_metrics_id(session.user_id)
2058 .await?;
2059 let user = session
2060 .db()
2061 .await
2062 .get_user_by_id(session.user_id)
2063 .await?
2064 .ok_or_else(|| anyhow!("user not found"))?;
2065 response.send(proto::GetPrivateUserInfoResponse {
2066 metrics_id,
2067 staff: user.admin,
2068 })?;
2069 Ok(())
2070}
2071
2072fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2073 match message {
2074 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2075 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2076 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2077 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2078 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2079 code: frame.code.into(),
2080 reason: frame.reason,
2081 })),
2082 }
2083}
2084
2085fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2086 match message {
2087 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2088 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2089 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2090 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2091 AxumMessage::Close(frame) => {
2092 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2093 code: frame.code.into(),
2094 reason: frame.reason,
2095 }))
2096 }
2097 }
2098}
2099
2100fn build_initial_contacts_update(
2101 contacts: Vec<db::Contact>,
2102 pool: &ConnectionPool,
2103) -> proto::UpdateContacts {
2104 let mut update = proto::UpdateContacts::default();
2105
2106 for contact in contacts {
2107 match contact {
2108 db::Contact::Accepted {
2109 user_id,
2110 should_notify,
2111 busy,
2112 } => {
2113 update
2114 .contacts
2115 .push(contact_for_user(user_id, should_notify, busy, &pool));
2116 }
2117 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2118 db::Contact::Incoming {
2119 user_id,
2120 should_notify,
2121 } => update
2122 .incoming_requests
2123 .push(proto::IncomingContactRequest {
2124 requester_id: user_id.to_proto(),
2125 should_notify,
2126 }),
2127 }
2128 }
2129
2130 update
2131}
2132
2133fn contact_for_user(
2134 user_id: UserId,
2135 should_notify: bool,
2136 busy: bool,
2137 pool: &ConnectionPool,
2138) -> proto::Contact {
2139 proto::Contact {
2140 user_id: user_id.to_proto(),
2141 online: pool.is_user_online(user_id),
2142 busy,
2143 should_notify,
2144 }
2145}
2146
2147fn room_updated(room: &proto::Room, peer: &Peer) {
2148 broadcast(
2149 None,
2150 room.participants
2151 .iter()
2152 .filter_map(|participant| Some(participant.peer_id?.into())),
2153 |peer_id| {
2154 peer.send(
2155 peer_id.into(),
2156 proto::RoomUpdated {
2157 room: Some(room.clone()),
2158 },
2159 )
2160 },
2161 );
2162}
2163
2164async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2165 let db = session.db().await;
2166 let contacts = db.get_contacts(user_id).await?;
2167 let busy = db.is_user_busy(user_id).await?;
2168
2169 let pool = session.connection_pool().await;
2170 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2171 for contact in contacts {
2172 if let db::Contact::Accepted {
2173 user_id: contact_user_id,
2174 ..
2175 } = contact
2176 {
2177 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2178 session
2179 .peer
2180 .send(
2181 contact_conn_id,
2182 proto::UpdateContacts {
2183 contacts: vec![updated_contact.clone()],
2184 remove_contacts: Default::default(),
2185 incoming_requests: Default::default(),
2186 remove_incoming_requests: Default::default(),
2187 outgoing_requests: Default::default(),
2188 remove_outgoing_requests: Default::default(),
2189 },
2190 )
2191 .trace_err();
2192 }
2193 }
2194 }
2195 Ok(())
2196}
2197
2198async fn leave_room_for_session(session: &Session) -> Result<()> {
2199 let mut contacts_to_update = HashSet::default();
2200
2201 let room_id;
2202 let canceled_calls_to_user_ids;
2203 let live_kit_room;
2204 let delete_live_kit_room;
2205 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2206 contacts_to_update.insert(session.user_id);
2207
2208 for project in left_room.left_projects.values() {
2209 project_left(project, session);
2210 }
2211
2212 room_updated(&left_room.room, &session.peer);
2213 room_id = RoomId::from_proto(left_room.room.id);
2214 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2215 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2216 delete_live_kit_room = left_room.room.participants.is_empty();
2217 } else {
2218 return Ok(());
2219 }
2220
2221 {
2222 let pool = session.connection_pool().await;
2223 for canceled_user_id in canceled_calls_to_user_ids {
2224 for connection_id in pool.user_connection_ids(canceled_user_id) {
2225 session
2226 .peer
2227 .send(
2228 connection_id,
2229 proto::CallCanceled {
2230 room_id: room_id.to_proto(),
2231 },
2232 )
2233 .trace_err();
2234 }
2235 contacts_to_update.insert(canceled_user_id);
2236 }
2237 }
2238
2239 for contact_user_id in contacts_to_update {
2240 update_user_contacts(contact_user_id, &session).await?;
2241 }
2242
2243 if let Some(live_kit) = session.live_kit_client.as_ref() {
2244 live_kit
2245 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2246 .await
2247 .trace_err();
2248
2249 if delete_live_kit_room {
2250 live_kit.delete_room(live_kit_room).await.trace_err();
2251 }
2252 }
2253
2254 Ok(())
2255}
2256
2257fn project_left(project: &db::LeftProject, session: &Session) {
2258 for connection_id in &project.connection_ids {
2259 if project.host_user_id == session.user_id {
2260 session
2261 .peer
2262 .send(
2263 *connection_id,
2264 proto::UnshareProject {
2265 project_id: project.id.to_proto(),
2266 },
2267 )
2268 .trace_err();
2269 } else {
2270 session
2271 .peer
2272 .send(
2273 *connection_id,
2274 proto::RemoveProjectCollaborator {
2275 project_id: project.id.to_proto(),
2276 peer_id: Some(session.connection_id.into()),
2277 },
2278 )
2279 .trace_err();
2280 }
2281 }
2282}
2283
2284pub trait ResultExt {
2285 type Ok;
2286
2287 fn trace_err(self) -> Option<Self::Ok>;
2288}
2289
2290impl<T, E> ResultExt for Result<T, E>
2291where
2292 E: std::fmt::Debug,
2293{
2294 type Ok = T;
2295
2296 fn trace_err(self) -> Option<T> {
2297 match self {
2298 Ok(value) => Some(value),
2299 Err(error) => {
2300 tracing::error!("{:?}", error);
2301 None
2302 }
2303 }
2304 }
2305}