1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits for a dropped client to reconnect before it leaves the
/// user's room and refreshes the user's contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before it looks for stale rooms to clean up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
// Prometheus gauges exported by `handle_metrics`. They are registered with the default
// registry on first access; the `unwrap`s panic only if registration fails.
lazy_static! {
    // Set by `handle_metrics` to the number of non-admin connections in the pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Set by `handle_metrics` from `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
/// Handle a request handler uses to send its reply to request `R`.
///
/// `responded` is shared with the dispatcher built in `Server::add_request_handler`.
/// The dispatcher treats a handler that returns `Ok(())` without setting this flag as an error.
struct Response<R> {
    // Peer used to deliver the reply.
    peer: Arc<Peer>,
    // Identifies the request being answered.
    receipt: Receipt<R>,
    // Set to `true` by `Response::send`.
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection state handed (by clone) to every message handler.
///
/// All fields are shared handles, so clones of a session refer to the same database lock,
/// peer, and connection pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquired through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared with `Server`; acquired through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when no LiveKit client is configured on the app state.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    executor: Executor,
}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
/// The collaboration RPC server: owns the peer, the pool of live connections, and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    // Mutable only so `reset` (test-only) can swap it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // One handler per message `TypeId`; see `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` sends on this. Connection loops and `sign_out` subscribe to it.
    teardown: watch::Sender<()>,
}
150
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so the compiler rejects holding it
/// across an `.await` inside futures that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
155
/// Serializable view of the server state produced by `Server::snapshot`.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref`, i.e. as the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_message_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(save_buffer)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 pub async fn start(&self) -> Result<()> {
247 let server_id = *self.id.lock();
248 let app_state = self.app_state.clone();
249 let peer = self.peer.clone();
250 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
251 let pool = self.connection_pool.clone();
252 let live_kit_client = self.app_state.live_kit_client.clone();
253
254 let span = info_span!("start server");
255 self.executor.spawn_detached(
256 async move {
257 tracing::info!("waiting for cleanup timeout");
258 timeout.await;
259 tracing::info!("cleanup timeout expired, retrieving stale rooms");
260 if let Some(room_ids) = app_state
261 .db
262 .stale_room_ids(&app_state.config.zed_environment, server_id)
263 .await
264 .trace_err()
265 {
266 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
267 for room_id in room_ids {
268 let mut contacts_to_update = HashSet::default();
269 let mut canceled_calls_to_user_ids = Vec::new();
270 let mut live_kit_room = String::new();
271 let mut delete_live_kit_room = false;
272
273 if let Ok(mut refreshed_room) =
274 app_state.db.refresh_room(room_id, server_id).await
275 {
276 tracing::info!(
277 room_id = room_id.0,
278 new_participant_count = refreshed_room.room.participants.len(),
279 "refreshed room"
280 );
281 room_updated(&refreshed_room.room, &peer);
282 contacts_to_update
283 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
284 contacts_to_update
285 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
286 canceled_calls_to_user_ids =
287 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
288 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
289 delete_live_kit_room = refreshed_room.room.participants.is_empty();
290 }
291
292 {
293 let pool = pool.lock();
294 for canceled_user_id in canceled_calls_to_user_ids {
295 for connection_id in pool.user_connection_ids(canceled_user_id) {
296 peer.send(
297 connection_id,
298 proto::CallCanceled {
299 room_id: room_id.to_proto(),
300 },
301 )
302 .trace_err();
303 }
304 }
305 }
306
307 for user_id in contacts_to_update {
308 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
309 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
310 if let Some((busy, contacts)) = busy.zip(contacts) {
311 let pool = pool.lock();
312 let updated_contact = contact_for_user(user_id, false, busy, &pool);
313 for contact in contacts {
314 if let db::Contact::Accepted {
315 user_id: contact_user_id,
316 ..
317 } = contact
318 {
319 for contact_conn_id in
320 pool.user_connection_ids(contact_user_id)
321 {
322 peer.send(
323 contact_conn_id,
324 proto::UpdateContacts {
325 contacts: vec![updated_contact.clone()],
326 remove_contacts: Default::default(),
327 incoming_requests: Default::default(),
328 remove_incoming_requests: Default::default(),
329 outgoing_requests: Default::default(),
330 remove_outgoing_requests: Default::default(),
331 },
332 )
333 .trace_err();
334 }
335 }
336 }
337 }
338 }
339
340 if let Some(live_kit) = live_kit_client.as_ref() {
341 if delete_live_kit_room {
342 live_kit.delete_room(live_kit_room).await.trace_err();
343 }
344 }
345 }
346 }
347
348 app_state
349 .db
350 .delete_stale_servers(&app_state.config.zed_environment, server_id)
351 .await
352 .trace_err();
353 }
354 .instrument(span),
355 );
356 Ok(())
357 }
358
359 pub fn teardown(&self) {
360 self.peer.teardown();
361 self.connection_pool.lock().reset();
362 let _ = self.teardown.send(());
363 }
364
365 #[cfg(test)]
366 pub fn reset(&self, id: ServerId) {
367 self.teardown();
368 *self.id.lock() = id;
369 self.peer.reset(id.0 as u32);
370 }
371
372 #[cfg(test)]
373 pub fn id(&self) -> ServerId {
374 *self.id.lock()
375 }
376
377 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
378 where
379 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
380 Fut: 'static + Send + Future<Output = Result<()>>,
381 M: EnvelopedMessage,
382 {
383 let prev_handler = self.handlers.insert(
384 TypeId::of::<M>(),
385 Box::new(move |envelope, session| {
386 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
387 let span = info_span!(
388 "handle message",
389 payload_type = envelope.payload_type_name()
390 );
391 span.in_scope(|| {
392 tracing::info!(
393 payload_type = envelope.payload_type_name(),
394 "message received"
395 );
396 });
397 let future = (handler)(*envelope, session);
398 async move {
399 if let Err(error) = future.await {
400 tracing::error!(%error, "error handling message");
401 }
402 }
403 .instrument(span)
404 .boxed()
405 }),
406 );
407 if prev_handler.is_some() {
408 panic!("registered a handler for the same message twice");
409 }
410 self
411 }
412
413 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
414 where
415 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
416 Fut: 'static + Send + Future<Output = Result<()>>,
417 M: EnvelopedMessage,
418 {
419 self.add_handler(move |envelope, session| handler(envelope.payload, session));
420 self
421 }
422
423 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
424 where
425 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
426 Fut: Send + Future<Output = Result<()>>,
427 M: RequestMessage,
428 {
429 let handler = Arc::new(handler);
430 self.add_handler(move |envelope, session| {
431 let receipt = envelope.receipt();
432 let handler = handler.clone();
433 async move {
434 let peer = session.peer.clone();
435 let responded = Arc::new(AtomicBool::default());
436 let response = Response {
437 peer: peer.clone(),
438 responded: responded.clone(),
439 receipt,
440 };
441 match (handler)(envelope.payload, response, session).await {
442 Ok(()) => {
443 if responded.load(std::sync::atomic::Ordering::SeqCst) {
444 Ok(())
445 } else {
446 Err(anyhow!("handler did not send a response"))?
447 }
448 }
449 Err(error) => {
450 peer.respond_with_error(
451 receipt,
452 proto::Error {
453 message: error.to_string(),
454 },
455 )?;
456 Err(error)
457 }
458 }
459 }
460 })
461 }
462
463 pub fn handle_connection(
464 self: &Arc<Self>,
465 connection: Connection,
466 address: String,
467 user: User,
468 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
469 executor: Executor,
470 ) -> impl Future<Output = Result<()>> {
471 let this = self.clone();
472 let user_id = user.id;
473 let login = user.github_login;
474 let span = info_span!("handle connection", %user_id, %login, %address);
475 let mut teardown = self.teardown.subscribe();
476 async move {
477 let (connection_id, handle_io, mut incoming_rx) = this
478 .peer
479 .add_connection(connection, {
480 let executor = executor.clone();
481 move |duration| executor.sleep(duration)
482 });
483
484 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
485 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
486 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
487
488 if let Some(send_connection_id) = send_connection_id.take() {
489 let _ = send_connection_id.send(connection_id);
490 }
491
492 if !user.connected_once {
493 this.peer.send(connection_id, proto::ShowContacts {})?;
494 this.app_state.db.set_user_connected_once(user_id, true).await?;
495 }
496
497 let (contacts, invite_code) = future::try_join(
498 this.app_state.db.get_contacts(user_id),
499 this.app_state.db.get_invite_code_for_user(user_id)
500 ).await?;
501
502 {
503 let mut pool = this.connection_pool.lock();
504 pool.add_connection(connection_id, user_id, user.admin);
505 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
506
507 if let Some((code, count)) = invite_code {
508 this.peer.send(connection_id, proto::UpdateInviteInfo {
509 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
510 count: count as u32,
511 })?;
512 }
513 }
514
515 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
516 this.peer.send(connection_id, incoming_call)?;
517 }
518
519 let session = Session {
520 user_id,
521 connection_id,
522 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
523 peer: this.peer.clone(),
524 connection_pool: this.connection_pool.clone(),
525 live_kit_client: this.app_state.live_kit_client.clone(),
526 executor: executor.clone(),
527 };
528 update_user_contacts(user_id, &session).await?;
529
530 let handle_io = handle_io.fuse();
531 futures::pin_mut!(handle_io);
532
533 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
534 // This prevents deadlocks when e.g., client A performs a request to client B and
535 // client B performs a request to client A. If both clients stop processing further
536 // messages until their respective request completes, they won't have a chance to
537 // respond to the other client's request and cause a deadlock.
538 //
539 // This arrangement ensures we will attempt to process earlier messages first, but fall
540 // back to processing messages arrived later in the spirit of making progress.
541 let mut foreground_message_handlers = FuturesUnordered::new();
542 loop {
543 let next_message = incoming_rx.next().fuse();
544 futures::pin_mut!(next_message);
545 futures::select_biased! {
546 _ = teardown.changed().fuse() => return Ok(()),
547 result = handle_io => {
548 if let Err(error) = result {
549 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
550 }
551 break;
552 }
553 _ = foreground_message_handlers.next() => {}
554 message = next_message => {
555 if let Some(message) = message {
556 let type_name = message.payload_type_name();
557 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
558 let span_enter = span.enter();
559 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
560 let is_background = message.is_background();
561 let handle_message = (handler)(message, session.clone());
562 drop(span_enter);
563
564 let handle_message = handle_message.instrument(span);
565 if is_background {
566 executor.spawn_detached(handle_message);
567 } else {
568 foreground_message_handlers.push(handle_message);
569 }
570 } else {
571 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
572 }
573 } else {
574 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
575 break;
576 }
577 }
578 }
579 }
580
581 drop(foreground_message_handlers);
582 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
583 if let Err(error) = sign_out(session, teardown, executor).await {
584 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
585 }
586
587 Ok(())
588 }.instrument(span)
589 }
590
591 pub async fn invite_code_redeemed(
592 self: &Arc<Self>,
593 inviter_id: UserId,
594 invitee_id: UserId,
595 ) -> Result<()> {
596 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
597 if let Some(code) = &user.invite_code {
598 let pool = self.connection_pool.lock();
599 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
600 for connection_id in pool.user_connection_ids(inviter_id) {
601 self.peer.send(
602 connection_id,
603 proto::UpdateContacts {
604 contacts: vec![invitee_contact.clone()],
605 ..Default::default()
606 },
607 )?;
608 self.peer.send(
609 connection_id,
610 proto::UpdateInviteInfo {
611 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
612 count: user.invite_count as u32,
613 },
614 )?;
615 }
616 }
617 }
618 Ok(())
619 }
620
621 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
622 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
623 if let Some(invite_code) = &user.invite_code {
624 let pool = self.connection_pool.lock();
625 for connection_id in pool.user_connection_ids(user_id) {
626 self.peer.send(
627 connection_id,
628 proto::UpdateInviteInfo {
629 url: format!(
630 "{}{}",
631 self.app_state.config.invite_link_prefix, invite_code
632 ),
633 count: user.invite_count as u32,
634 },
635 )?;
636 }
637 }
638 }
639 Ok(())
640 }
641
642 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
643 ServerSnapshot {
644 connection_pool: ConnectionPoolGuard {
645 guard: self.connection_pool.lock(),
646 _not_send: PhantomData,
647 },
648 peer: &self.peer,
649 }
650 }
651}
652
653impl<'a> Deref for ConnectionPoolGuard<'a> {
654 type Target = ConnectionPool;
655
656 fn deref(&self) -> &Self::Target {
657 &*self.guard
658 }
659}
660
661impl<'a> DerefMut for ConnectionPoolGuard<'a> {
662 fn deref_mut(&mut self) -> &mut Self::Target {
663 &mut *self.guard
664 }
665}
666
667impl<'a> Drop for ConnectionPoolGuard<'a> {
668 fn drop(&mut self) {
669 #[cfg(test)]
670 self.check_invariants();
671 }
672}
673
674fn broadcast<F>(
675 sender_id: ConnectionId,
676 receiver_ids: impl IntoIterator<Item = ConnectionId>,
677 mut f: F,
678) where
679 F: FnMut(ConnectionId) -> anyhow::Result<()>,
680{
681 for receiver_id in receiver_ids {
682 if receiver_id != sender_id {
683 f(receiver_id).trace_err();
684 }
685 }
686}
687
// Name of the HTTP header that carries the client's RPC protocol version.
lazy_static! {
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed `x-zed-protocol-version` header value. `handle_websocket_request` compares it
/// against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
693
694impl Header for ProtocolVersion {
695 fn name() -> &'static HeaderName {
696 &ZED_PROTOCOL_VERSION
697 }
698
699 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
700 where
701 Self: Sized,
702 I: Iterator<Item = &'i axum::http::HeaderValue>,
703 {
704 let version = values
705 .next()
706 .ok_or_else(axum::headers::Error::invalid)?
707 .to_str()
708 .map_err(|_| axum::headers::Error::invalid())?
709 .parse()
710 .map_err(|_| axum::headers::Error::invalid())?;
711 Ok(Self(version))
712 }
713
714 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
715 values.extend([self.0.to_string().parse().unwrap()]);
716 }
717}
718
719pub fn routes(server: Arc<Server>) -> Router<Body> {
720 Router::new()
721 .route("/rpc", get(handle_websocket_request))
722 .layer(
723 ServiceBuilder::new()
724 .layer(Extension(server.app_state.clone()))
725 .layer(middleware::from_fn(auth::validate_header)),
726 )
727 .route("/metrics", get(handle_metrics))
728 .layer(Extension(server))
729}
730
731pub async fn handle_websocket_request(
732 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
733 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
734 Extension(server): Extension<Arc<Server>>,
735 Extension(user): Extension<User>,
736 ws: WebSocketUpgrade,
737) -> axum::response::Response {
738 if protocol_version != rpc::PROTOCOL_VERSION {
739 return (
740 StatusCode::UPGRADE_REQUIRED,
741 "client must be upgraded".to_string(),
742 )
743 .into_response();
744 }
745 let socket_address = socket_address.to_string();
746 ws.on_upgrade(move |socket| {
747 use util::ResultExt;
748 let socket = socket
749 .map_ok(to_tungstenite_message)
750 .err_into()
751 .with(|message| async move { Ok(to_axum_message(message)) });
752 let connection = Connection::new(Box::pin(socket));
753 async move {
754 server
755 .handle_connection(connection, socket_address, user, None, Executor::Production)
756 .await
757 .log_err();
758 }
759 })
760}
761
762pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
763 let connections = server
764 .connection_pool
765 .lock()
766 .connections()
767 .filter(|connection| !connection.admin)
768 .count();
769
770 METRIC_CONNECTIONS.set(connections as _);
771
772 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
773 METRIC_SHARED_PROJECTS.set(shared_projects as _);
774
775 let encoder = prometheus::TextEncoder::new();
776 let metric_families = prometheus::gather();
777 let encoded_metrics = encoder
778 .encode_to_string(&metric_families)
779 .map_err(|err| anyhow!("{}", err))?;
780 Ok(encoded_metrics)
781}
782
/// Cleans up after a connection stops being served.
///
/// Immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (an error there is returned);
/// - reports it to the database via `connection_lost` (errors are only logged).
///
/// It then waits for whichever comes first, [`RECONNECT_TIMEOUT`] or a server teardown:
/// - Timeout first: the session leaves its room. If none of the user's connections are
///   online, it calls `decline_call(None, user_id)` and broadcasts the returned room, if any.
///   It then calls `update_user_contacts` for the user.
/// - Teardown first: nothing further is done.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` checks the timeout branch first when both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline when the user has no other live connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
823
824async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
825 response.send(proto::Ack {})?;
826 Ok(())
827}
828
829async fn create_room(
830 _request: proto::CreateRoom,
831 response: Response<proto::CreateRoom>,
832 session: Session,
833) -> Result<()> {
834 let live_kit_room = nanoid::nanoid!(30);
835 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
836 if let Some(_) = live_kit
837 .create_room(live_kit_room.clone())
838 .await
839 .trace_err()
840 {
841 if let Some(token) = live_kit
842 .room_token(&live_kit_room, &session.user_id.to_string())
843 .trace_err()
844 {
845 Some(proto::LiveKitConnectionInfo {
846 server_url: live_kit.url().into(),
847 token,
848 })
849 } else {
850 None
851 }
852 } else {
853 None
854 }
855 } else {
856 None
857 };
858
859 {
860 let room = session
861 .db()
862 .await
863 .create_room(session.user_id, session.connection_id, &live_kit_room)
864 .await?;
865
866 response.send(proto::CreateRoomResponse {
867 room: Some(room.clone()),
868 live_kit_connection_info,
869 })?;
870 }
871
872 update_user_contacts(session.user_id, &session).await?;
873 Ok(())
874}
875
876async fn join_room(
877 request: proto::JoinRoom,
878 response: Response<proto::JoinRoom>,
879 session: Session,
880) -> Result<()> {
881 let room_id = RoomId::from_proto(request.id);
882 let room = {
883 let room = session
884 .db()
885 .await
886 .join_room(room_id, session.user_id, session.connection_id)
887 .await?;
888 room_updated(&room, &session.peer);
889 room.clone()
890 };
891
892 for connection_id in session
893 .connection_pool()
894 .await
895 .user_connection_ids(session.user_id)
896 {
897 session
898 .peer
899 .send(
900 connection_id,
901 proto::CallCanceled {
902 room_id: room_id.to_proto(),
903 },
904 )
905 .trace_err();
906 }
907
908 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
909 if let Some(token) = live_kit
910 .room_token(&room.live_kit_room, &session.user_id.to_string())
911 .trace_err()
912 {
913 Some(proto::LiveKitConnectionInfo {
914 server_url: live_kit.url().into(),
915 token,
916 })
917 } else {
918 None
919 }
920 } else {
921 None
922 };
923
924 response.send(proto::JoinRoomResponse {
925 room: Some(room),
926 live_kit_connection_info,
927 })?;
928
929 update_user_contacts(session.user_id, &session).await?;
930 Ok(())
931}
932
933async fn rejoin_room(
934 request: proto::RejoinRoom,
935 response: Response<proto::RejoinRoom>,
936 session: Session,
937) -> Result<()> {
938 {
939 let mut rejoined_room = session
940 .db()
941 .await
942 .rejoin_room(request, session.user_id, session.connection_id)
943 .await?;
944
945 response.send(proto::RejoinRoomResponse {
946 room: Some(rejoined_room.room.clone()),
947 reshared_projects: rejoined_room
948 .reshared_projects
949 .iter()
950 .map(|project| proto::ResharedProject {
951 id: project.id.to_proto(),
952 collaborators: project
953 .collaborators
954 .iter()
955 .map(|collaborator| collaborator.to_proto())
956 .collect(),
957 })
958 .collect(),
959 rejoined_projects: rejoined_room
960 .rejoined_projects
961 .iter()
962 .map(|rejoined_project| proto::RejoinedProject {
963 id: rejoined_project.id.to_proto(),
964 worktrees: rejoined_project
965 .worktrees
966 .iter()
967 .map(|worktree| proto::WorktreeMetadata {
968 id: worktree.id,
969 root_name: worktree.root_name.clone(),
970 visible: worktree.visible,
971 abs_path: worktree.abs_path.clone(),
972 })
973 .collect(),
974 collaborators: rejoined_project
975 .collaborators
976 .iter()
977 .map(|collaborator| collaborator.to_proto())
978 .collect(),
979 language_servers: rejoined_project.language_servers.clone(),
980 })
981 .collect(),
982 })?;
983 room_updated(&rejoined_room.room, &session.peer);
984
985 for project in &rejoined_room.reshared_projects {
986 for collaborator in &project.collaborators {
987 session
988 .peer
989 .send(
990 collaborator.connection_id,
991 proto::UpdateProjectCollaborator {
992 project_id: project.id.to_proto(),
993 old_peer_id: Some(project.old_connection_id.into()),
994 new_peer_id: Some(session.connection_id.into()),
995 },
996 )
997 .trace_err();
998 }
999
1000 broadcast(
1001 session.connection_id,
1002 project
1003 .collaborators
1004 .iter()
1005 .map(|collaborator| collaborator.connection_id),
1006 |connection_id| {
1007 session.peer.forward_send(
1008 session.connection_id,
1009 connection_id,
1010 proto::UpdateProject {
1011 project_id: project.id.to_proto(),
1012 worktrees: project.worktrees.clone(),
1013 },
1014 )
1015 },
1016 );
1017 }
1018
1019 for project in &rejoined_room.rejoined_projects {
1020 for collaborator in &project.collaborators {
1021 session
1022 .peer
1023 .send(
1024 collaborator.connection_id,
1025 proto::UpdateProjectCollaborator {
1026 project_id: project.id.to_proto(),
1027 old_peer_id: Some(project.old_connection_id.into()),
1028 new_peer_id: Some(session.connection_id.into()),
1029 },
1030 )
1031 .trace_err();
1032 }
1033 }
1034
1035 for project in &mut rejoined_room.rejoined_projects {
1036 for worktree in mem::take(&mut project.worktrees) {
1037 #[cfg(any(test, feature = "test-support"))]
1038 const MAX_CHUNK_SIZE: usize = 2;
1039 #[cfg(not(any(test, feature = "test-support")))]
1040 const MAX_CHUNK_SIZE: usize = 256;
1041
1042 // Stream this worktree's entries.
1043 let message = proto::UpdateWorktree {
1044 project_id: project.id.to_proto(),
1045 worktree_id: worktree.id,
1046 abs_path: worktree.abs_path.clone(),
1047 root_name: worktree.root_name,
1048 updated_entries: worktree.updated_entries,
1049 removed_entries: worktree.removed_entries,
1050 scan_id: worktree.scan_id,
1051 is_last_update: worktree.is_complete,
1052 };
1053 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1054 session.peer.send(session.connection_id, update.clone())?;
1055 }
1056
1057 // Stream this worktree's diagnostics.
1058 for summary in worktree.diagnostic_summaries {
1059 session.peer.send(
1060 session.connection_id,
1061 proto::UpdateDiagnosticSummary {
1062 project_id: project.id.to_proto(),
1063 worktree_id: worktree.id,
1064 summary: Some(summary),
1065 },
1066 )?;
1067 }
1068 }
1069
1070 for language_server in &project.language_servers {
1071 session.peer.send(
1072 session.connection_id,
1073 proto::UpdateLanguageServer {
1074 project_id: project.id.to_proto(),
1075 language_server_id: language_server.id,
1076 variant: Some(
1077 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1078 proto::LspDiskBasedDiagnosticsUpdated {},
1079 ),
1080 ),
1081 },
1082 )?;
1083 }
1084 }
1085 }
1086
1087 update_user_contacts(session.user_id, &session).await?;
1088 Ok(())
1089}
1090
1091async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
1092 leave_room_for_session(&session).await
1093}
1094
1095async fn call(
1096 request: proto::Call,
1097 response: Response<proto::Call>,
1098 session: Session,
1099) -> Result<()> {
1100 let room_id = RoomId::from_proto(request.room_id);
1101 let calling_user_id = session.user_id;
1102 let calling_connection_id = session.connection_id;
1103 let called_user_id = UserId::from_proto(request.called_user_id);
1104 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1105 if !session
1106 .db()
1107 .await
1108 .has_contact(calling_user_id, called_user_id)
1109 .await?
1110 {
1111 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1112 }
1113
1114 let incoming_call = {
1115 let (room, incoming_call) = &mut *session
1116 .db()
1117 .await
1118 .call(
1119 room_id,
1120 calling_user_id,
1121 calling_connection_id,
1122 called_user_id,
1123 initial_project_id,
1124 )
1125 .await?;
1126 room_updated(&room, &session.peer);
1127 mem::take(incoming_call)
1128 };
1129 update_user_contacts(called_user_id, &session).await?;
1130
1131 let mut calls = session
1132 .connection_pool()
1133 .await
1134 .user_connection_ids(called_user_id)
1135 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1136 .collect::<FuturesUnordered<_>>();
1137
1138 while let Some(call_response) = calls.next().await {
1139 match call_response.as_ref() {
1140 Ok(_) => {
1141 response.send(proto::Ack {})?;
1142 return Ok(());
1143 }
1144 Err(_) => {
1145 call_response.trace_err();
1146 }
1147 }
1148 }
1149
1150 {
1151 let room = session
1152 .db()
1153 .await
1154 .call_failed(room_id, called_user_id)
1155 .await?;
1156 room_updated(&room, &session.peer);
1157 }
1158 update_user_contacts(called_user_id, &session).await?;
1159
1160 Err(anyhow!("failed to ring user"))?
1161}
1162
1163async fn cancel_call(
1164 request: proto::CancelCall,
1165 response: Response<proto::CancelCall>,
1166 session: Session,
1167) -> Result<()> {
1168 let called_user_id = UserId::from_proto(request.called_user_id);
1169 let room_id = RoomId::from_proto(request.room_id);
1170 {
1171 let room = session
1172 .db()
1173 .await
1174 .cancel_call(room_id, session.connection_id, called_user_id)
1175 .await?;
1176 room_updated(&room, &session.peer);
1177 }
1178
1179 for connection_id in session
1180 .connection_pool()
1181 .await
1182 .user_connection_ids(called_user_id)
1183 {
1184 session
1185 .peer
1186 .send(
1187 connection_id,
1188 proto::CallCanceled {
1189 room_id: room_id.to_proto(),
1190 },
1191 )
1192 .trace_err();
1193 }
1194 response.send(proto::Ack {})?;
1195
1196 update_user_contacts(called_user_id, &session).await?;
1197 Ok(())
1198}
1199
1200async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1201 let room_id = RoomId::from_proto(message.room_id);
1202 {
1203 let room = session
1204 .db()
1205 .await
1206 .decline_call(Some(room_id), session.user_id)
1207 .await?
1208 .ok_or_else(|| anyhow!("failed to decline call"))?;
1209 room_updated(&room, &session.peer);
1210 }
1211
1212 for connection_id in session
1213 .connection_pool()
1214 .await
1215 .user_connection_ids(session.user_id)
1216 {
1217 session
1218 .peer
1219 .send(
1220 connection_id,
1221 proto::CallCanceled {
1222 room_id: room_id.to_proto(),
1223 },
1224 )
1225 .trace_err();
1226 }
1227 update_user_contacts(session.user_id, &session).await?;
1228 Ok(())
1229}
1230
1231async fn update_participant_location(
1232 request: proto::UpdateParticipantLocation,
1233 response: Response<proto::UpdateParticipantLocation>,
1234 session: Session,
1235) -> Result<()> {
1236 let room_id = RoomId::from_proto(request.room_id);
1237 let location = request
1238 .location
1239 .ok_or_else(|| anyhow!("invalid location"))?;
1240 let room = session
1241 .db()
1242 .await
1243 .update_room_participant_location(room_id, session.connection_id, location)
1244 .await?;
1245 room_updated(&room, &session.peer);
1246 response.send(proto::Ack {})?;
1247 Ok(())
1248}
1249
1250async fn share_project(
1251 request: proto::ShareProject,
1252 response: Response<proto::ShareProject>,
1253 session: Session,
1254) -> Result<()> {
1255 let (project_id, room) = &*session
1256 .db()
1257 .await
1258 .share_project(
1259 RoomId::from_proto(request.room_id),
1260 session.connection_id,
1261 &request.worktrees,
1262 )
1263 .await?;
1264 response.send(proto::ShareProjectResponse {
1265 project_id: project_id.to_proto(),
1266 })?;
1267 room_updated(&room, &session.peer);
1268
1269 Ok(())
1270}
1271
1272async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1273 let project_id = ProjectId::from_proto(message.project_id);
1274
1275 let (room, guest_connection_ids) = &*session
1276 .db()
1277 .await
1278 .unshare_project(project_id, session.connection_id)
1279 .await?;
1280
1281 broadcast(
1282 session.connection_id,
1283 guest_connection_ids.iter().copied(),
1284 |conn_id| session.peer.send(conn_id, message.clone()),
1285 );
1286 room_updated(&room, &session.peer);
1287
1288 Ok(())
1289}
1290
/// Adds this connection to `request.project_id` as a guest and streams the
/// project's state to it.
///
/// Order of operations:
/// 1. Register the guest in the database, obtaining the project and the
///    guest's replica id.
/// 2. Send `AddProjectCollaborator` to every existing collaborator other than
///    this connection (failures are logged and ignored).
/// 3. Reply with the worktree metadata, the replica id, the other
///    collaborators, and the project's language servers.
/// 4. For each worktree, send its entries as `UpdateWorktree` messages split by
///    `proto::split_worktree_update`, followed by its diagnostic summaries.
/// 5. Send a `DiskBasedDiagnosticsUpdated` update for each language server.
///
/// Any send failure in steps 3-5 aborts the handler with an error.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let guest_user_id = session.user_id;

    tracing::info!(%project_id, "join project");

    // `project` and `replica_id` borrow from the value returned by
    // `join_project`, so that value stays alive until this function returns,
    // i.e. across every send below.
    let (project, replica_id) = &mut *session
        .db()
        .await
        .join_project(project_id, session.connection_id)
        .await?;

    // Everyone already in the project, excluding the joining connection.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Introduce the new guest to each existing collaborator.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                // NOTE(review): panics if a collaborator has no peer id;
                // `room_updated` instead logs and skips such entries.
                collaborator.peer_id.unwrap().into(),
                proto::AddProjectCollaborator {
                    project_id: project_id.to_proto(),
                    collaborator: Some(proto::Collaborator {
                        peer_id: Some(session.connection_id.into()),
                        replica_id: replica_id.0 as u32,
                        user_id: guest_user_id.to_proto(),
                    }),
                },
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
    })?;

    // Worktrees are moved out of the project so their entries can be sent
    // without cloning.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Test builds use a tiny chunk size, presumably so tests exercise
        // entries being split across several messages.
        #[cfg(any(test, feature = "test-support"))]
        const MAX_CHUNK_SIZE: usize = 2;
        #[cfg(not(any(test, feature = "test-support")))]
        const MAX_CHUNK_SIZE: usize = 256;

        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.is_complete,
        };
        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    // NOTE(review): uses `worktree.id` where the entry stream
                    // above uses the map key `worktree_id`; presumably the same
                    // value — confirm.
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }
    }

    // Tell the guest each language server's disk-based diagnostics are up to date.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1400
1401async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1402 let sender_id = session.connection_id;
1403 let project_id = ProjectId::from_proto(request.project_id);
1404
1405 let project = session
1406 .db()
1407 .await
1408 .leave_project(project_id, sender_id)
1409 .await?;
1410 tracing::info!(
1411 %project_id,
1412 host_user_id = %project.host_user_id,
1413 host_connection_id = %project.host_connection_id,
1414 "leave project"
1415 );
1416 project_left(&project, &session);
1417
1418 Ok(())
1419}
1420
1421async fn update_project(
1422 request: proto::UpdateProject,
1423 response: Response<proto::UpdateProject>,
1424 session: Session,
1425) -> Result<()> {
1426 let project_id = ProjectId::from_proto(request.project_id);
1427 let (room, guest_connection_ids) = &*session
1428 .db()
1429 .await
1430 .update_project(project_id, session.connection_id, &request.worktrees)
1431 .await?;
1432 broadcast(
1433 session.connection_id,
1434 guest_connection_ids.iter().copied(),
1435 |connection_id| {
1436 session
1437 .peer
1438 .forward_send(session.connection_id, connection_id, request.clone())
1439 },
1440 );
1441 room_updated(&room, &session.peer);
1442 response.send(proto::Ack {})?;
1443
1444 Ok(())
1445}
1446
1447async fn update_worktree(
1448 request: proto::UpdateWorktree,
1449 response: Response<proto::UpdateWorktree>,
1450 session: Session,
1451) -> Result<()> {
1452 let guest_connection_ids = session
1453 .db()
1454 .await
1455 .update_worktree(&request, session.connection_id)
1456 .await?;
1457
1458 broadcast(
1459 session.connection_id,
1460 guest_connection_ids.iter().copied(),
1461 |connection_id| {
1462 session
1463 .peer
1464 .forward_send(session.connection_id, connection_id, request.clone())
1465 },
1466 );
1467 response.send(proto::Ack {})?;
1468 Ok(())
1469}
1470
1471async fn update_diagnostic_summary(
1472 message: proto::UpdateDiagnosticSummary,
1473 session: Session,
1474) -> Result<()> {
1475 let guest_connection_ids = session
1476 .db()
1477 .await
1478 .update_diagnostic_summary(&message, session.connection_id)
1479 .await?;
1480
1481 broadcast(
1482 session.connection_id,
1483 guest_connection_ids.iter().copied(),
1484 |connection_id| {
1485 session
1486 .peer
1487 .forward_send(session.connection_id, connection_id, message.clone())
1488 },
1489 );
1490
1491 Ok(())
1492}
1493
1494async fn start_language_server(
1495 request: proto::StartLanguageServer,
1496 session: Session,
1497) -> Result<()> {
1498 let guest_connection_ids = session
1499 .db()
1500 .await
1501 .start_language_server(&request, session.connection_id)
1502 .await?;
1503
1504 broadcast(
1505 session.connection_id,
1506 guest_connection_ids.iter().copied(),
1507 |connection_id| {
1508 session
1509 .peer
1510 .forward_send(session.connection_id, connection_id, request.clone())
1511 },
1512 );
1513 Ok(())
1514}
1515
1516async fn update_language_server(
1517 request: proto::UpdateLanguageServer,
1518 session: Session,
1519) -> Result<()> {
1520 session.executor.record_backtrace();
1521 let project_id = ProjectId::from_proto(request.project_id);
1522 let project_connection_ids = session
1523 .db()
1524 .await
1525 .project_connection_ids(project_id, session.connection_id)
1526 .await?;
1527 broadcast(
1528 session.connection_id,
1529 project_connection_ids.iter().copied(),
1530 |connection_id| {
1531 session
1532 .peer
1533 .forward_send(session.connection_id, connection_id, request.clone())
1534 },
1535 );
1536 Ok(())
1537}
1538
1539async fn forward_project_request<T>(
1540 request: T,
1541 response: Response<T>,
1542 session: Session,
1543) -> Result<()>
1544where
1545 T: EntityMessage + RequestMessage,
1546{
1547 session.executor.record_backtrace();
1548 let project_id = ProjectId::from_proto(request.remote_entity_id());
1549 let host_connection_id = {
1550 let collaborators = session
1551 .db()
1552 .await
1553 .project_collaborators(project_id, session.connection_id)
1554 .await?;
1555 collaborators
1556 .iter()
1557 .find(|collaborator| collaborator.is_host)
1558 .ok_or_else(|| anyhow!("host not found"))?
1559 .connection_id
1560 };
1561
1562 let payload = session
1563 .peer
1564 .forward_request(session.connection_id, host_connection_id, request)
1565 .await?;
1566
1567 response.send(payload)?;
1568 Ok(())
1569}
1570
1571async fn save_buffer(
1572 request: proto::SaveBuffer,
1573 response: Response<proto::SaveBuffer>,
1574 session: Session,
1575) -> Result<()> {
1576 let project_id = ProjectId::from_proto(request.project_id);
1577 let host_connection_id = {
1578 let collaborators = session
1579 .db()
1580 .await
1581 .project_collaborators(project_id, session.connection_id)
1582 .await?;
1583 collaborators
1584 .iter()
1585 .find(|collaborator| collaborator.is_host)
1586 .ok_or_else(|| anyhow!("host not found"))?
1587 .connection_id
1588 };
1589 let response_payload = session
1590 .peer
1591 .forward_request(session.connection_id, host_connection_id, request.clone())
1592 .await?;
1593
1594 let mut collaborators = session
1595 .db()
1596 .await
1597 .project_collaborators(project_id, session.connection_id)
1598 .await?;
1599 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1600 let project_connection_ids = collaborators
1601 .iter()
1602 .map(|collaborator| collaborator.connection_id);
1603 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1604 session
1605 .peer
1606 .forward_send(host_connection_id, conn_id, response_payload.clone())
1607 });
1608 response.send(response_payload)?;
1609 Ok(())
1610}
1611
1612async fn create_buffer_for_peer(
1613 request: proto::CreateBufferForPeer,
1614 session: Session,
1615) -> Result<()> {
1616 session.executor.record_backtrace();
1617 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1618 session
1619 .peer
1620 .forward_send(session.connection_id, peer_id.into(), request)?;
1621 Ok(())
1622}
1623
1624async fn update_buffer(
1625 request: proto::UpdateBuffer,
1626 response: Response<proto::UpdateBuffer>,
1627 session: Session,
1628) -> Result<()> {
1629 session.executor.record_backtrace();
1630 let project_id = ProjectId::from_proto(request.project_id);
1631 let project_connection_ids = session
1632 .db()
1633 .await
1634 .project_connection_ids(project_id, session.connection_id)
1635 .await?;
1636
1637 session.executor.record_backtrace();
1638
1639 broadcast(
1640 session.connection_id,
1641 project_connection_ids.iter().copied(),
1642 |connection_id| {
1643 session
1644 .peer
1645 .forward_send(session.connection_id, connection_id, request.clone())
1646 },
1647 );
1648 response.send(proto::Ack {})?;
1649 Ok(())
1650}
1651
1652async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1653 let project_id = ProjectId::from_proto(request.project_id);
1654 let project_connection_ids = session
1655 .db()
1656 .await
1657 .project_connection_ids(project_id, session.connection_id)
1658 .await?;
1659
1660 broadcast(
1661 session.connection_id,
1662 project_connection_ids.iter().copied(),
1663 |connection_id| {
1664 session
1665 .peer
1666 .forward_send(session.connection_id, connection_id, request.clone())
1667 },
1668 );
1669 Ok(())
1670}
1671
1672async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1673 let project_id = ProjectId::from_proto(request.project_id);
1674 let project_connection_ids = session
1675 .db()
1676 .await
1677 .project_connection_ids(project_id, session.connection_id)
1678 .await?;
1679 broadcast(
1680 session.connection_id,
1681 project_connection_ids.iter().copied(),
1682 |connection_id| {
1683 session
1684 .peer
1685 .forward_send(session.connection_id, connection_id, request.clone())
1686 },
1687 );
1688 Ok(())
1689}
1690
1691async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1692 let project_id = ProjectId::from_proto(request.project_id);
1693 let project_connection_ids = session
1694 .db()
1695 .await
1696 .project_connection_ids(project_id, session.connection_id)
1697 .await?;
1698 broadcast(
1699 session.connection_id,
1700 project_connection_ids.iter().copied(),
1701 |connection_id| {
1702 session
1703 .peer
1704 .forward_send(session.connection_id, connection_id, request.clone())
1705 },
1706 );
1707 Ok(())
1708}
1709
1710async fn follow(
1711 request: proto::Follow,
1712 response: Response<proto::Follow>,
1713 session: Session,
1714) -> Result<()> {
1715 let project_id = ProjectId::from_proto(request.project_id);
1716 let leader_id = request
1717 .leader_id
1718 .ok_or_else(|| anyhow!("invalid leader id"))?
1719 .into();
1720 let follower_id = session.connection_id;
1721 {
1722 let project_connection_ids = session
1723 .db()
1724 .await
1725 .project_connection_ids(project_id, session.connection_id)
1726 .await?;
1727
1728 if !project_connection_ids.contains(&leader_id) {
1729 Err(anyhow!("no such peer"))?;
1730 }
1731 }
1732
1733 let mut response_payload = session
1734 .peer
1735 .forward_request(session.connection_id, leader_id, request)
1736 .await?;
1737 response_payload
1738 .views
1739 .retain(|view| view.leader_id != Some(follower_id.into()));
1740 response.send(response_payload)?;
1741 Ok(())
1742}
1743
1744async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1745 let project_id = ProjectId::from_proto(request.project_id);
1746 let leader_id = request
1747 .leader_id
1748 .ok_or_else(|| anyhow!("invalid leader id"))?
1749 .into();
1750 let project_connection_ids = session
1751 .db()
1752 .await
1753 .project_connection_ids(project_id, session.connection_id)
1754 .await?;
1755 if !project_connection_ids.contains(&leader_id) {
1756 Err(anyhow!("no such peer"))?;
1757 }
1758 session
1759 .peer
1760 .forward_send(session.connection_id, leader_id, request)?;
1761 Ok(())
1762}
1763
1764async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1765 let project_id = ProjectId::from_proto(request.project_id);
1766 let project_connection_ids = session
1767 .db
1768 .lock()
1769 .await
1770 .project_connection_ids(project_id, session.connection_id)
1771 .await?;
1772
1773 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1774 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1775 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1776 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1777 });
1778 for follower_peer_id in request.follower_ids.iter().copied() {
1779 let follower_connection_id = follower_peer_id.into();
1780 if project_connection_ids.contains(&follower_connection_id)
1781 && Some(follower_peer_id) != leader_id
1782 {
1783 session.peer.forward_send(
1784 session.connection_id,
1785 follower_connection_id,
1786 request.clone(),
1787 )?;
1788 }
1789 }
1790 Ok(())
1791}
1792
1793async fn get_users(
1794 request: proto::GetUsers,
1795 response: Response<proto::GetUsers>,
1796 session: Session,
1797) -> Result<()> {
1798 let user_ids = request
1799 .user_ids
1800 .into_iter()
1801 .map(UserId::from_proto)
1802 .collect();
1803 let users = session
1804 .db()
1805 .await
1806 .get_users_by_ids(user_ids)
1807 .await?
1808 .into_iter()
1809 .map(|user| proto::User {
1810 id: user.id.to_proto(),
1811 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1812 github_login: user.github_login,
1813 })
1814 .collect();
1815 response.send(proto::UsersResponse { users })?;
1816 Ok(())
1817}
1818
1819async fn fuzzy_search_users(
1820 request: proto::FuzzySearchUsers,
1821 response: Response<proto::FuzzySearchUsers>,
1822 session: Session,
1823) -> Result<()> {
1824 let query = request.query;
1825 let users = match query.len() {
1826 0 => vec![],
1827 1 | 2 => session
1828 .db()
1829 .await
1830 .get_user_by_github_account(&query, None)
1831 .await?
1832 .into_iter()
1833 .collect(),
1834 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1835 };
1836 let users = users
1837 .into_iter()
1838 .filter(|user| user.id != session.user_id)
1839 .map(|user| proto::User {
1840 id: user.id.to_proto(),
1841 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1842 github_login: user.github_login,
1843 })
1844 .collect();
1845 response.send(proto::UsersResponse { users })?;
1846 Ok(())
1847}
1848
1849async fn request_contact(
1850 request: proto::RequestContact,
1851 response: Response<proto::RequestContact>,
1852 session: Session,
1853) -> Result<()> {
1854 let requester_id = session.user_id;
1855 let responder_id = UserId::from_proto(request.responder_id);
1856 if requester_id == responder_id {
1857 return Err(anyhow!("cannot add yourself as a contact"))?;
1858 }
1859
1860 session
1861 .db()
1862 .await
1863 .send_contact_request(requester_id, responder_id)
1864 .await?;
1865
1866 // Update outgoing contact requests of requester
1867 let mut update = proto::UpdateContacts::default();
1868 update.outgoing_requests.push(responder_id.to_proto());
1869 for connection_id in session
1870 .connection_pool()
1871 .await
1872 .user_connection_ids(requester_id)
1873 {
1874 session.peer.send(connection_id, update.clone())?;
1875 }
1876
1877 // Update incoming contact requests of responder
1878 let mut update = proto::UpdateContacts::default();
1879 update
1880 .incoming_requests
1881 .push(proto::IncomingContactRequest {
1882 requester_id: requester_id.to_proto(),
1883 should_notify: true,
1884 });
1885 for connection_id in session
1886 .connection_pool()
1887 .await
1888 .user_connection_ids(responder_id)
1889 {
1890 session.peer.send(connection_id, update.clone())?;
1891 }
1892
1893 response.send(proto::Ack {})?;
1894 Ok(())
1895}
1896
1897async fn respond_to_contact_request(
1898 request: proto::RespondToContactRequest,
1899 response: Response<proto::RespondToContactRequest>,
1900 session: Session,
1901) -> Result<()> {
1902 let responder_id = session.user_id;
1903 let requester_id = UserId::from_proto(request.requester_id);
1904 let db = session.db().await;
1905 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1906 db.dismiss_contact_notification(responder_id, requester_id)
1907 .await?;
1908 } else {
1909 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1910
1911 db.respond_to_contact_request(responder_id, requester_id, accept)
1912 .await?;
1913 let requester_busy = db.is_user_busy(requester_id).await?;
1914 let responder_busy = db.is_user_busy(responder_id).await?;
1915
1916 let pool = session.connection_pool().await;
1917 // Update responder with new contact
1918 let mut update = proto::UpdateContacts::default();
1919 if accept {
1920 update
1921 .contacts
1922 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1923 }
1924 update
1925 .remove_incoming_requests
1926 .push(requester_id.to_proto());
1927 for connection_id in pool.user_connection_ids(responder_id) {
1928 session.peer.send(connection_id, update.clone())?;
1929 }
1930
1931 // Update requester with new contact
1932 let mut update = proto::UpdateContacts::default();
1933 if accept {
1934 update
1935 .contacts
1936 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1937 }
1938 update
1939 .remove_outgoing_requests
1940 .push(responder_id.to_proto());
1941 for connection_id in pool.user_connection_ids(requester_id) {
1942 session.peer.send(connection_id, update.clone())?;
1943 }
1944 }
1945
1946 response.send(proto::Ack {})?;
1947 Ok(())
1948}
1949
1950async fn remove_contact(
1951 request: proto::RemoveContact,
1952 response: Response<proto::RemoveContact>,
1953 session: Session,
1954) -> Result<()> {
1955 let requester_id = session.user_id;
1956 let responder_id = UserId::from_proto(request.user_id);
1957 let db = session.db().await;
1958 db.remove_contact(requester_id, responder_id).await?;
1959
1960 let pool = session.connection_pool().await;
1961 // Update outgoing contact requests of requester
1962 let mut update = proto::UpdateContacts::default();
1963 update
1964 .remove_outgoing_requests
1965 .push(responder_id.to_proto());
1966 for connection_id in pool.user_connection_ids(requester_id) {
1967 session.peer.send(connection_id, update.clone())?;
1968 }
1969
1970 // Update incoming contact requests of responder
1971 let mut update = proto::UpdateContacts::default();
1972 update
1973 .remove_incoming_requests
1974 .push(requester_id.to_proto());
1975 for connection_id in pool.user_connection_ids(responder_id) {
1976 session.peer.send(connection_id, update.clone())?;
1977 }
1978
1979 response.send(proto::Ack {})?;
1980 Ok(())
1981}
1982
1983async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1984 let project_id = ProjectId::from_proto(request.project_id);
1985 let project_connection_ids = session
1986 .db()
1987 .await
1988 .project_connection_ids(project_id, session.connection_id)
1989 .await?;
1990 broadcast(
1991 session.connection_id,
1992 project_connection_ids.iter().copied(),
1993 |connection_id| {
1994 session
1995 .peer
1996 .forward_send(session.connection_id, connection_id, request.clone())
1997 },
1998 );
1999 Ok(())
2000}
2001
2002async fn get_private_user_info(
2003 _request: proto::GetPrivateUserInfo,
2004 response: Response<proto::GetPrivateUserInfo>,
2005 session: Session,
2006) -> Result<()> {
2007 let metrics_id = session
2008 .db()
2009 .await
2010 .get_user_metrics_id(session.user_id)
2011 .await?;
2012 let user = session
2013 .db()
2014 .await
2015 .get_user_by_id(session.user_id)
2016 .await?
2017 .ok_or_else(|| anyhow!("user not found"))?;
2018 response.send(proto::GetPrivateUserInfoResponse {
2019 metrics_id,
2020 staff: user.admin,
2021 })?;
2022 Ok(())
2023}
2024
2025fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2026 match message {
2027 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2028 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2029 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2030 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2031 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2032 code: frame.code.into(),
2033 reason: frame.reason,
2034 })),
2035 }
2036}
2037
2038fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2039 match message {
2040 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2041 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2042 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2043 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2044 AxumMessage::Close(frame) => {
2045 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2046 code: frame.code.into(),
2047 reason: frame.reason,
2048 }))
2049 }
2050 }
2051}
2052
2053fn build_initial_contacts_update(
2054 contacts: Vec<db::Contact>,
2055 pool: &ConnectionPool,
2056) -> proto::UpdateContacts {
2057 let mut update = proto::UpdateContacts::default();
2058
2059 for contact in contacts {
2060 match contact {
2061 db::Contact::Accepted {
2062 user_id,
2063 should_notify,
2064 busy,
2065 } => {
2066 update
2067 .contacts
2068 .push(contact_for_user(user_id, should_notify, busy, &pool));
2069 }
2070 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2071 db::Contact::Incoming {
2072 user_id,
2073 should_notify,
2074 } => update
2075 .incoming_requests
2076 .push(proto::IncomingContactRequest {
2077 requester_id: user_id.to_proto(),
2078 should_notify,
2079 }),
2080 }
2081 }
2082
2083 update
2084}
2085
2086fn contact_for_user(
2087 user_id: UserId,
2088 should_notify: bool,
2089 busy: bool,
2090 pool: &ConnectionPool,
2091) -> proto::Contact {
2092 proto::Contact {
2093 user_id: user_id.to_proto(),
2094 online: pool.is_user_online(user_id),
2095 busy,
2096 should_notify,
2097 }
2098}
2099
2100fn room_updated(room: &proto::Room, peer: &Peer) {
2101 for participant in &room.participants {
2102 if let Some(peer_id) = participant
2103 .peer_id
2104 .ok_or_else(|| anyhow!("invalid participant peer id"))
2105 .trace_err()
2106 {
2107 peer.send(
2108 peer_id.into(),
2109 proto::RoomUpdated {
2110 room: Some(room.clone()),
2111 },
2112 )
2113 .trace_err();
2114 }
2115 }
2116}
2117
2118async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2119 let db = session.db().await;
2120 let contacts = db.get_contacts(user_id).await?;
2121 let busy = db.is_user_busy(user_id).await?;
2122
2123 let pool = session.connection_pool().await;
2124 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2125 for contact in contacts {
2126 if let db::Contact::Accepted {
2127 user_id: contact_user_id,
2128 ..
2129 } = contact
2130 {
2131 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2132 session
2133 .peer
2134 .send(
2135 contact_conn_id,
2136 proto::UpdateContacts {
2137 contacts: vec![updated_contact.clone()],
2138 remove_contacts: Default::default(),
2139 incoming_requests: Default::default(),
2140 remove_incoming_requests: Default::default(),
2141 outgoing_requests: Default::default(),
2142 remove_outgoing_requests: Default::default(),
2143 },
2144 )
2145 .trace_err();
2146 }
2147 }
2148 }
2149 Ok(())
2150}
2151
/// Removes this session's connection from the room it is in, if any, and
/// performs the follow-up notifications.
///
/// Returns `Ok(())` immediately when `leave_room` returns `None`. Otherwise it:
/// - calls `project_left` for each project left along with the room, and
///   broadcasts the updated room;
/// - sends `CallCanceled` to every connection of users whose pending calls
///   were canceled;
/// - refreshes the contact status of this user and of those users;
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room, and deletes that room when no participants remain.
///
/// Send and LiveKit failures are logged and ignored; database and contact
/// update errors are propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned only in the `Some` branch below; the `else` branch
    // returns, so all four are initialized after the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // Checked after this connection left, so `true` means the room is now empty.
        delete_live_kit_room = left_room.room.participants.is_empty();
    } else {
        return Ok(());
    }

    // Scoped so the connection pool guard is released before
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2210
2211fn project_left(project: &db::LeftProject, session: &Session) {
2212 for connection_id in &project.connection_ids {
2213 if project.host_user_id == session.user_id {
2214 session
2215 .peer
2216 .send(
2217 *connection_id,
2218 proto::UnshareProject {
2219 project_id: project.id.to_proto(),
2220 },
2221 )
2222 .trace_err();
2223 } else {
2224 session
2225 .peer
2226 .send(
2227 *connection_id,
2228 proto::RemoveProjectCollaborator {
2229 project_id: project.id.to_proto(),
2230 peer_id: Some(session.connection_id.into()),
2231 },
2232 )
2233 .trace_err();
2234 }
2235 }
2236}
2237
2238pub trait ResultExt {
2239 type Ok;
2240
2241 fn trace_err(self) -> Option<Self::Ok>;
2242}
2243
2244impl<T, E> ResultExt for Result<T, E>
2245where
2246 E: std::fmt::Debug,
2247{
2248 type Ok = T;
2249
2250 fn trace_err(self) -> Option<T> {
2251 match self {
2252 Ok(value) => Some(value),
2253 Err(error) => {
2254 tracing::error!("{:?}", error);
2255 None
2256 }
2257 }
2258 }
2259}