1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection drops before leaving the user's room and,
/// if the user has no other live connection, declining their pending call.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before cleaning up rooms left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
lazy_static! {
    // Prometheus gauges, set by `handle_metrics` on every request to `/metrics`.
    // Registration happens lazily on first access; `.unwrap()` panics if the registry
    // rejects the gauge (e.g. a name collision).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
/// Handle passed to request handlers for replying to a single request of type `R`.
///
/// `responded` is shared with the dispatcher in `Server::add_request_handler`, which reads
/// it after the handler finishes to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection state; `handle_connection` hands a clone of it to every message handler.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Async (tokio) mutex, so the handle may be held across `.await`; lock it through
    // `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // The same pool the `Server` owns; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // `None` when no LiveKit client is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The collaboration RPC server: owns the peer, the pool of live connections and the
/// table of registered message handlers.
pub struct Server {
    // Behind a mutex because the test-only `reset` replaces it in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Fired by `Server::teardown`; connection loops and `sign_out` subscribe to it.
    teardown: watch::Sender<()>,
}
149
/// Lock guard over the connection pool.
///
/// The `Rc` marker makes it intentionally `!Send`, so it cannot be held across an `.await`
/// inside a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
154
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The connection-pool lock stays held while the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 id: parking_lot::Mutex::new(id),
175 peer: Peer::new(id.0 as u32),
176 app_state,
177 executor,
178 connection_pool: Default::default(),
179 handlers: Default::default(),
180 teardown: watch::channel(()).0,
181 };
182
183 server
184 .add_request_handler(ping)
185 .add_request_handler(create_room)
186 .add_request_handler(join_room)
187 .add_request_handler(rejoin_room)
188 .add_message_handler(leave_room)
189 .add_request_handler(call)
190 .add_request_handler(cancel_call)
191 .add_message_handler(decline_call)
192 .add_request_handler(update_participant_location)
193 .add_request_handler(share_project)
194 .add_message_handler(unshare_project)
195 .add_request_handler(join_project)
196 .add_message_handler(leave_project)
197 .add_request_handler(update_project)
198 .add_request_handler(update_worktree)
199 .add_message_handler(start_language_server)
200 .add_message_handler(update_language_server)
201 .add_message_handler(update_diagnostic_summary)
202 .add_request_handler(forward_project_request::<proto::GetHover>)
203 .add_request_handler(forward_project_request::<proto::GetDefinition>)
204 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetReferences>)
206 .add_request_handler(forward_project_request::<proto::SearchProject>)
207 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
208 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
212 .add_request_handler(forward_project_request::<proto::GetCompletions>)
213 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
214 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
216 .add_request_handler(forward_project_request::<proto::PrepareRename>)
217 .add_request_handler(forward_project_request::<proto::PerformRename>)
218 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
219 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
220 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
222 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
224 .add_message_handler(create_buffer_for_peer)
225 .add_request_handler(update_buffer)
226 .add_message_handler(update_buffer_file)
227 .add_message_handler(buffer_reloaded)
228 .add_message_handler(buffer_saved)
229 .add_request_handler(save_buffer)
230 .add_request_handler(get_users)
231 .add_request_handler(fuzzy_search_users)
232 .add_request_handler(request_contact)
233 .add_request_handler(remove_contact)
234 .add_request_handler(respond_to_contact_request)
235 .add_request_handler(follow)
236 .add_message_handler(unfollow)
237 .add_message_handler(update_followers)
238 .add_message_handler(update_diff_base)
239 .add_request_handler(get_private_user_info);
240
241 Arc::new(server)
242 }
243
244 pub async fn start(&self) -> Result<()> {
245 let server_id = *self.id.lock();
246 let app_state = self.app_state.clone();
247 let peer = self.peer.clone();
248 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
249 let pool = self.connection_pool.clone();
250 let live_kit_client = self.app_state.live_kit_client.clone();
251
252 let span = info_span!("start server");
253 self.executor.spawn_detached(
254 async move {
255 tracing::info!("waiting for cleanup timeout");
256 timeout.await;
257 tracing::info!("cleanup timeout expired, retrieving stale rooms");
258 if let Some(room_ids) = app_state
259 .db
260 .stale_room_ids(&app_state.config.zed_environment, server_id)
261 .await
262 .trace_err()
263 {
264 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
265 for room_id in room_ids {
266 let mut contacts_to_update = HashSet::default();
267 let mut canceled_calls_to_user_ids = Vec::new();
268 let mut live_kit_room = String::new();
269 let mut delete_live_kit_room = false;
270
271 if let Ok(mut refreshed_room) =
272 app_state.db.refresh_room(room_id, server_id).await
273 {
274 tracing::info!(
275 room_id = room_id.0,
276 new_participant_count = refreshed_room.room.participants.len(),
277 "refreshed room"
278 );
279 room_updated(&refreshed_room.room, &peer);
280 contacts_to_update
281 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
282 contacts_to_update
283 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
284 canceled_calls_to_user_ids =
285 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
286 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
287 delete_live_kit_room = refreshed_room.room.participants.is_empty();
288 }
289
290 {
291 let pool = pool.lock();
292 for canceled_user_id in canceled_calls_to_user_ids {
293 for connection_id in pool.user_connection_ids(canceled_user_id) {
294 peer.send(
295 connection_id,
296 proto::CallCanceled {
297 room_id: room_id.to_proto(),
298 },
299 )
300 .trace_err();
301 }
302 }
303 }
304
305 for user_id in contacts_to_update {
306 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
307 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
308 if let Some((busy, contacts)) = busy.zip(contacts) {
309 let pool = pool.lock();
310 let updated_contact = contact_for_user(user_id, false, busy, &pool);
311 for contact in contacts {
312 if let db::Contact::Accepted {
313 user_id: contact_user_id,
314 ..
315 } = contact
316 {
317 for contact_conn_id in
318 pool.user_connection_ids(contact_user_id)
319 {
320 peer.send(
321 contact_conn_id,
322 proto::UpdateContacts {
323 contacts: vec![updated_contact.clone()],
324 remove_contacts: Default::default(),
325 incoming_requests: Default::default(),
326 remove_incoming_requests: Default::default(),
327 outgoing_requests: Default::default(),
328 remove_outgoing_requests: Default::default(),
329 },
330 )
331 .trace_err();
332 }
333 }
334 }
335 }
336 }
337
338 if let Some(live_kit) = live_kit_client.as_ref() {
339 if delete_live_kit_room {
340 live_kit.delete_room(live_kit_room).await.trace_err();
341 }
342 }
343 }
344 }
345
346 app_state
347 .db
348 .delete_stale_servers(&app_state.config.zed_environment, server_id)
349 .await
350 .trace_err();
351 }
352 .instrument(span),
353 );
354 Ok(())
355 }
356
357 pub fn teardown(&self) {
358 self.peer.teardown();
359 self.connection_pool.lock().reset();
360 let _ = self.teardown.send(());
361 }
362
363 #[cfg(test)]
364 pub fn reset(&self, id: ServerId) {
365 self.teardown();
366 *self.id.lock() = id;
367 self.peer.reset(id.0 as u32);
368 }
369
370 #[cfg(test)]
371 pub fn id(&self) -> ServerId {
372 *self.id.lock()
373 }
374
375 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
376 where
377 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
378 Fut: 'static + Send + Future<Output = Result<()>>,
379 M: EnvelopedMessage,
380 {
381 let prev_handler = self.handlers.insert(
382 TypeId::of::<M>(),
383 Box::new(move |envelope, session| {
384 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
385 let span = info_span!(
386 "handle message",
387 payload_type = envelope.payload_type_name()
388 );
389 span.in_scope(|| {
390 tracing::info!(
391 payload_type = envelope.payload_type_name(),
392 "message received"
393 );
394 });
395 let future = (handler)(*envelope, session);
396 async move {
397 if let Err(error) = future.await {
398 tracing::error!(%error, "error handling message");
399 }
400 }
401 .instrument(span)
402 .boxed()
403 }),
404 );
405 if prev_handler.is_some() {
406 panic!("registered a handler for the same message twice");
407 }
408 self
409 }
410
411 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
412 where
413 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
414 Fut: 'static + Send + Future<Output = Result<()>>,
415 M: EnvelopedMessage,
416 {
417 self.add_handler(move |envelope, session| handler(envelope.payload, session));
418 self
419 }
420
421 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
422 where
423 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
424 Fut: Send + Future<Output = Result<()>>,
425 M: RequestMessage,
426 {
427 let handler = Arc::new(handler);
428 self.add_handler(move |envelope, session| {
429 let receipt = envelope.receipt();
430 let handler = handler.clone();
431 async move {
432 let peer = session.peer.clone();
433 let responded = Arc::new(AtomicBool::default());
434 let response = Response {
435 peer: peer.clone(),
436 responded: responded.clone(),
437 receipt,
438 };
439 match (handler)(envelope.payload, response, session).await {
440 Ok(()) => {
441 if responded.load(std::sync::atomic::Ordering::SeqCst) {
442 Ok(())
443 } else {
444 Err(anyhow!("handler did not send a response"))?
445 }
446 }
447 Err(error) => {
448 peer.respond_with_error(
449 receipt,
450 proto::Error {
451 message: error.to_string(),
452 },
453 )?;
454 Err(error)
455 }
456 }
457 }
458 })
459 }
460
461 pub fn handle_connection(
462 self: &Arc<Self>,
463 connection: Connection,
464 address: String,
465 user: User,
466 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
467 executor: Executor,
468 ) -> impl Future<Output = Result<()>> {
469 let this = self.clone();
470 let user_id = user.id;
471 let login = user.github_login;
472 let span = info_span!("handle connection", %user_id, %login, %address);
473 let mut teardown = self.teardown.subscribe();
474 async move {
475 let (connection_id, handle_io, mut incoming_rx) = this
476 .peer
477 .add_connection(connection, {
478 let executor = executor.clone();
479 move |duration| executor.sleep(duration)
480 });
481
482 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
483 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
484 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
485
486 if let Some(send_connection_id) = send_connection_id.take() {
487 let _ = send_connection_id.send(connection_id);
488 }
489
490 if !user.connected_once {
491 this.peer.send(connection_id, proto::ShowContacts {})?;
492 this.app_state.db.set_user_connected_once(user_id, true).await?;
493 }
494
495 let (contacts, invite_code) = future::try_join(
496 this.app_state.db.get_contacts(user_id),
497 this.app_state.db.get_invite_code_for_user(user_id)
498 ).await?;
499
500 {
501 let mut pool = this.connection_pool.lock();
502 pool.add_connection(connection_id, user_id, user.admin);
503 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
504
505 if let Some((code, count)) = invite_code {
506 this.peer.send(connection_id, proto::UpdateInviteInfo {
507 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
508 count: count as u32,
509 })?;
510 }
511 }
512
513 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
514 this.peer.send(connection_id, incoming_call)?;
515 }
516
517 let session = Session {
518 user_id,
519 connection_id,
520 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
521 peer: this.peer.clone(),
522 connection_pool: this.connection_pool.clone(),
523 live_kit_client: this.app_state.live_kit_client.clone()
524 };
525 update_user_contacts(user_id, &session).await?;
526
527 let handle_io = handle_io.fuse();
528 futures::pin_mut!(handle_io);
529
530 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
531 // This prevents deadlocks when e.g., client A performs a request to client B and
532 // client B performs a request to client A. If both clients stop processing further
533 // messages until their respective request completes, they won't have a chance to
534 // respond to the other client's request and cause a deadlock.
535 //
536 // This arrangement ensures we will attempt to process earlier messages first, but fall
537 // back to processing messages arrived later in the spirit of making progress.
538 let mut foreground_message_handlers = FuturesUnordered::new();
539 loop {
540 let next_message = incoming_rx.next().fuse();
541 futures::pin_mut!(next_message);
542 futures::select_biased! {
543 _ = teardown.changed().fuse() => return Ok(()),
544 result = handle_io => {
545 if let Err(error) = result {
546 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
547 }
548 break;
549 }
550 _ = foreground_message_handlers.next() => {}
551 message = next_message => {
552 if let Some(message) = message {
553 let type_name = message.payload_type_name();
554 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
555 let span_enter = span.enter();
556 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
557 let is_background = message.is_background();
558 let handle_message = (handler)(message, session.clone());
559 drop(span_enter);
560
561 let handle_message = handle_message.instrument(span);
562 if is_background {
563 executor.spawn_detached(handle_message);
564 } else {
565 foreground_message_handlers.push(handle_message);
566 }
567 } else {
568 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
569 }
570 } else {
571 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
572 break;
573 }
574 }
575 }
576 }
577
578 drop(foreground_message_handlers);
579 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
580 if let Err(error) = sign_out(session, teardown, executor).await {
581 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
582 }
583
584 Ok(())
585 }.instrument(span)
586 }
587
588 pub async fn invite_code_redeemed(
589 self: &Arc<Self>,
590 inviter_id: UserId,
591 invitee_id: UserId,
592 ) -> Result<()> {
593 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
594 if let Some(code) = &user.invite_code {
595 let pool = self.connection_pool.lock();
596 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
597 for connection_id in pool.user_connection_ids(inviter_id) {
598 self.peer.send(
599 connection_id,
600 proto::UpdateContacts {
601 contacts: vec![invitee_contact.clone()],
602 ..Default::default()
603 },
604 )?;
605 self.peer.send(
606 connection_id,
607 proto::UpdateInviteInfo {
608 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
609 count: user.invite_count as u32,
610 },
611 )?;
612 }
613 }
614 }
615 Ok(())
616 }
617
618 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
619 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
620 if let Some(invite_code) = &user.invite_code {
621 let pool = self.connection_pool.lock();
622 for connection_id in pool.user_connection_ids(user_id) {
623 self.peer.send(
624 connection_id,
625 proto::UpdateInviteInfo {
626 url: format!(
627 "{}{}",
628 self.app_state.config.invite_link_prefix, invite_code
629 ),
630 count: user.invite_count as u32,
631 },
632 )?;
633 }
634 }
635 }
636 Ok(())
637 }
638
639 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
640 ServerSnapshot {
641 connection_pool: ConnectionPoolGuard {
642 guard: self.connection_pool.lock(),
643 _not_send: PhantomData,
644 },
645 peer: &self.peer,
646 }
647 }
648}
649
650impl<'a> Deref for ConnectionPoolGuard<'a> {
651 type Target = ConnectionPool;
652
653 fn deref(&self) -> &Self::Target {
654 &*self.guard
655 }
656}
657
658impl<'a> DerefMut for ConnectionPoolGuard<'a> {
659 fn deref_mut(&mut self) -> &mut Self::Target {
660 &mut *self.guard
661 }
662}
663
664impl<'a> Drop for ConnectionPoolGuard<'a> {
665 fn drop(&mut self) {
666 #[cfg(test)]
667 self.check_invariants();
668 }
669}
670
671fn broadcast<F>(
672 sender_id: ConnectionId,
673 receiver_ids: impl IntoIterator<Item = ConnectionId>,
674 mut f: F,
675) where
676 F: FnMut(ConnectionId) -> anyhow::Result<()>,
677{
678 for receiver_id in receiver_ids {
679 if receiver_id != sender_id {
680 f(receiver_id).trace_err();
681 }
682 }
683}
684
685lazy_static! {
686 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
687}
688
689pub struct ProtocolVersion(u32);
690
691impl Header for ProtocolVersion {
692 fn name() -> &'static HeaderName {
693 &ZED_PROTOCOL_VERSION
694 }
695
696 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
697 where
698 Self: Sized,
699 I: Iterator<Item = &'i axum::http::HeaderValue>,
700 {
701 let version = values
702 .next()
703 .ok_or_else(axum::headers::Error::invalid)?
704 .to_str()
705 .map_err(|_| axum::headers::Error::invalid())?
706 .parse()
707 .map_err(|_| axum::headers::Error::invalid())?;
708 Ok(Self(version))
709 }
710
711 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
712 values.extend([self.0.to_string().parse().unwrap()]);
713 }
714}
715
716pub fn routes(server: Arc<Server>) -> Router<Body> {
717 Router::new()
718 .route("/rpc", get(handle_websocket_request))
719 .layer(
720 ServiceBuilder::new()
721 .layer(Extension(server.app_state.clone()))
722 .layer(middleware::from_fn(auth::validate_header)),
723 )
724 .route("/metrics", get(handle_metrics))
725 .layer(Extension(server))
726}
727
728pub async fn handle_websocket_request(
729 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
730 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
731 Extension(server): Extension<Arc<Server>>,
732 Extension(user): Extension<User>,
733 ws: WebSocketUpgrade,
734) -> axum::response::Response {
735 if protocol_version != rpc::PROTOCOL_VERSION {
736 return (
737 StatusCode::UPGRADE_REQUIRED,
738 "client must be upgraded".to_string(),
739 )
740 .into_response();
741 }
742 let socket_address = socket_address.to_string();
743 ws.on_upgrade(move |socket| {
744 use util::ResultExt;
745 let socket = socket
746 .map_ok(to_tungstenite_message)
747 .err_into()
748 .with(|message| async move { Ok(to_axum_message(message)) });
749 let connection = Connection::new(Box::pin(socket));
750 async move {
751 server
752 .handle_connection(connection, socket_address, user, None, Executor::Production)
753 .await
754 .log_err();
755 }
756 })
757}
758
759pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
760 let connections = server
761 .connection_pool
762 .lock()
763 .connections()
764 .filter(|connection| !connection.admin)
765 .count();
766
767 METRIC_CONNECTIONS.set(connections as _);
768
769 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
770 METRIC_SHARED_PROJECTS.set(shared_projects as _);
771
772 let encoder = prometheus::TextEncoder::new();
773 let metric_families = prometheus::gather();
774 let encoded_metrics = encoder
775 .encode_to_string(&metric_families)
776 .map_err(|err| anyhow!("{}", err))?;
777 Ok(encoded_metrics)
778}
779
780#[instrument(err, skip(executor))]
781async fn sign_out(
782 session: Session,
783 mut teardown: watch::Receiver<()>,
784 executor: Executor,
785) -> Result<()> {
786 session.peer.disconnect(session.connection_id);
787 session
788 .connection_pool()
789 .await
790 .remove_connection(session.connection_id)?;
791
792 session
793 .db()
794 .await
795 .connection_lost(session.connection_id)
796 .await
797 .trace_err();
798
799 futures::select_biased! {
800 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
801 leave_room_for_session(&session).await.trace_err();
802
803 if !session
804 .connection_pool()
805 .await
806 .is_user_online(session.user_id)
807 {
808 let db = session.db().await;
809 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
810 room_updated(&room, &session.peer);
811 }
812 }
813 update_user_contacts(session.user_id, &session).await?;
814 }
815 _ = teardown.changed().fuse() => {}
816 }
817
818 Ok(())
819}
820
821async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
822 response.send(proto::Ack {})?;
823 Ok(())
824}
825
826async fn create_room(
827 _request: proto::CreateRoom,
828 response: Response<proto::CreateRoom>,
829 session: Session,
830) -> Result<()> {
831 let live_kit_room = nanoid::nanoid!(30);
832 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
833 if let Some(_) = live_kit
834 .create_room(live_kit_room.clone())
835 .await
836 .trace_err()
837 {
838 if let Some(token) = live_kit
839 .room_token(&live_kit_room, &session.user_id.to_string())
840 .trace_err()
841 {
842 Some(proto::LiveKitConnectionInfo {
843 server_url: live_kit.url().into(),
844 token,
845 })
846 } else {
847 None
848 }
849 } else {
850 None
851 }
852 } else {
853 None
854 };
855
856 {
857 let room = session
858 .db()
859 .await
860 .create_room(session.user_id, session.connection_id, &live_kit_room)
861 .await?;
862
863 response.send(proto::CreateRoomResponse {
864 room: Some(room.clone()),
865 live_kit_connection_info,
866 })?;
867 }
868
869 update_user_contacts(session.user_id, &session).await?;
870 Ok(())
871}
872
873async fn join_room(
874 request: proto::JoinRoom,
875 response: Response<proto::JoinRoom>,
876 session: Session,
877) -> Result<()> {
878 let room_id = RoomId::from_proto(request.id);
879 let room = {
880 let room = session
881 .db()
882 .await
883 .join_room(room_id, session.user_id, session.connection_id)
884 .await?;
885 room_updated(&room, &session.peer);
886 room.clone()
887 };
888
889 for connection_id in session
890 .connection_pool()
891 .await
892 .user_connection_ids(session.user_id)
893 {
894 session
895 .peer
896 .send(
897 connection_id,
898 proto::CallCanceled {
899 room_id: room_id.to_proto(),
900 },
901 )
902 .trace_err();
903 }
904
905 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
906 if let Some(token) = live_kit
907 .room_token(&room.live_kit_room, &session.user_id.to_string())
908 .trace_err()
909 {
910 Some(proto::LiveKitConnectionInfo {
911 server_url: live_kit.url().into(),
912 token,
913 })
914 } else {
915 None
916 }
917 } else {
918 None
919 };
920
921 response.send(proto::JoinRoomResponse {
922 room: Some(room),
923 live_kit_connection_info,
924 })?;
925
926 update_user_contacts(session.user_id, &session).await?;
927 Ok(())
928}
929
930async fn rejoin_room(
931 request: proto::RejoinRoom,
932 response: Response<proto::RejoinRoom>,
933 session: Session,
934) -> Result<()> {
935 {
936 let mut rejoined_room = session
937 .db()
938 .await
939 .rejoin_room(request, session.user_id, session.connection_id)
940 .await?;
941
942 response.send(proto::RejoinRoomResponse {
943 room: Some(rejoined_room.room.clone()),
944 reshared_projects: rejoined_room
945 .reshared_projects
946 .iter()
947 .map(|project| proto::ResharedProject {
948 id: project.id.to_proto(),
949 collaborators: project
950 .collaborators
951 .iter()
952 .map(|collaborator| collaborator.to_proto())
953 .collect(),
954 })
955 .collect(),
956 rejoined_projects: rejoined_room
957 .rejoined_projects
958 .iter()
959 .map(|rejoined_project| proto::RejoinedProject {
960 id: rejoined_project.id.to_proto(),
961 worktrees: rejoined_project
962 .worktrees
963 .iter()
964 .map(|worktree| proto::WorktreeMetadata {
965 id: worktree.id,
966 root_name: worktree.root_name.clone(),
967 visible: worktree.visible,
968 abs_path: worktree.abs_path.clone(),
969 })
970 .collect(),
971 collaborators: rejoined_project
972 .collaborators
973 .iter()
974 .map(|collaborator| collaborator.to_proto())
975 .collect(),
976 language_servers: rejoined_project.language_servers.clone(),
977 })
978 .collect(),
979 })?;
980 room_updated(&rejoined_room.room, &session.peer);
981
982 for project in &rejoined_room.reshared_projects {
983 for collaborator in &project.collaborators {
984 session
985 .peer
986 .send(
987 collaborator.connection_id,
988 proto::UpdateProjectCollaborator {
989 project_id: project.id.to_proto(),
990 old_peer_id: Some(project.old_connection_id.into()),
991 new_peer_id: Some(session.connection_id.into()),
992 },
993 )
994 .trace_err();
995 }
996
997 broadcast(
998 session.connection_id,
999 project
1000 .collaborators
1001 .iter()
1002 .map(|collaborator| collaborator.connection_id),
1003 |connection_id| {
1004 session.peer.forward_send(
1005 session.connection_id,
1006 connection_id,
1007 proto::UpdateProject {
1008 project_id: project.id.to_proto(),
1009 worktrees: project.worktrees.clone(),
1010 },
1011 )
1012 },
1013 );
1014 }
1015
1016 for project in &rejoined_room.rejoined_projects {
1017 for collaborator in &project.collaborators {
1018 session
1019 .peer
1020 .send(
1021 collaborator.connection_id,
1022 proto::UpdateProjectCollaborator {
1023 project_id: project.id.to_proto(),
1024 old_peer_id: Some(project.old_connection_id.into()),
1025 new_peer_id: Some(session.connection_id.into()),
1026 },
1027 )
1028 .trace_err();
1029 }
1030 }
1031
1032 for project in &mut rejoined_room.rejoined_projects {
1033 for worktree in mem::take(&mut project.worktrees) {
1034 #[cfg(any(test, feature = "test-support"))]
1035 const MAX_CHUNK_SIZE: usize = 2;
1036 #[cfg(not(any(test, feature = "test-support")))]
1037 const MAX_CHUNK_SIZE: usize = 256;
1038
1039 // Stream this worktree's entries.
1040 let message = proto::UpdateWorktree {
1041 project_id: project.id.to_proto(),
1042 worktree_id: worktree.id,
1043 abs_path: worktree.abs_path.clone(),
1044 root_name: worktree.root_name,
1045 updated_entries: worktree.updated_entries,
1046 removed_entries: worktree.removed_entries,
1047 scan_id: worktree.scan_id,
1048 is_last_update: worktree.is_complete,
1049 };
1050 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1051 session.peer.send(session.connection_id, update.clone())?;
1052 }
1053
1054 // Stream this worktree's diagnostics.
1055 for summary in worktree.diagnostic_summaries {
1056 session.peer.send(
1057 session.connection_id,
1058 proto::UpdateDiagnosticSummary {
1059 project_id: project.id.to_proto(),
1060 worktree_id: worktree.id,
1061 summary: Some(summary),
1062 },
1063 )?;
1064 }
1065 }
1066
1067 for language_server in &project.language_servers {
1068 session.peer.send(
1069 session.connection_id,
1070 proto::UpdateLanguageServer {
1071 project_id: project.id.to_proto(),
1072 language_server_id: language_server.id,
1073 variant: Some(
1074 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1075 proto::LspDiskBasedDiagnosticsUpdated {},
1076 ),
1077 ),
1078 },
1079 )?;
1080 }
1081 }
1082 }
1083
1084 update_user_contacts(session.user_id, &session).await?;
1085 Ok(())
1086}
1087
1088async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
1089 leave_room_for_session(&session).await
1090}
1091
1092async fn call(
1093 request: proto::Call,
1094 response: Response<proto::Call>,
1095 session: Session,
1096) -> Result<()> {
1097 let room_id = RoomId::from_proto(request.room_id);
1098 let calling_user_id = session.user_id;
1099 let calling_connection_id = session.connection_id;
1100 let called_user_id = UserId::from_proto(request.called_user_id);
1101 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1102 if !session
1103 .db()
1104 .await
1105 .has_contact(calling_user_id, called_user_id)
1106 .await?
1107 {
1108 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1109 }
1110
1111 let incoming_call = {
1112 let (room, incoming_call) = &mut *session
1113 .db()
1114 .await
1115 .call(
1116 room_id,
1117 calling_user_id,
1118 calling_connection_id,
1119 called_user_id,
1120 initial_project_id,
1121 )
1122 .await?;
1123 room_updated(&room, &session.peer);
1124 mem::take(incoming_call)
1125 };
1126 update_user_contacts(called_user_id, &session).await?;
1127
1128 let mut calls = session
1129 .connection_pool()
1130 .await
1131 .user_connection_ids(called_user_id)
1132 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1133 .collect::<FuturesUnordered<_>>();
1134
1135 while let Some(call_response) = calls.next().await {
1136 match call_response.as_ref() {
1137 Ok(_) => {
1138 response.send(proto::Ack {})?;
1139 return Ok(());
1140 }
1141 Err(_) => {
1142 call_response.trace_err();
1143 }
1144 }
1145 }
1146
1147 {
1148 let room = session
1149 .db()
1150 .await
1151 .call_failed(room_id, called_user_id)
1152 .await?;
1153 room_updated(&room, &session.peer);
1154 }
1155 update_user_contacts(called_user_id, &session).await?;
1156
1157 Err(anyhow!("failed to ring user"))?
1158}
1159
1160async fn cancel_call(
1161 request: proto::CancelCall,
1162 response: Response<proto::CancelCall>,
1163 session: Session,
1164) -> Result<()> {
1165 let called_user_id = UserId::from_proto(request.called_user_id);
1166 let room_id = RoomId::from_proto(request.room_id);
1167 {
1168 let room = session
1169 .db()
1170 .await
1171 .cancel_call(room_id, session.connection_id, called_user_id)
1172 .await?;
1173 room_updated(&room, &session.peer);
1174 }
1175
1176 for connection_id in session
1177 .connection_pool()
1178 .await
1179 .user_connection_ids(called_user_id)
1180 {
1181 session
1182 .peer
1183 .send(
1184 connection_id,
1185 proto::CallCanceled {
1186 room_id: room_id.to_proto(),
1187 },
1188 )
1189 .trace_err();
1190 }
1191 response.send(proto::Ack {})?;
1192
1193 update_user_contacts(called_user_id, &session).await?;
1194 Ok(())
1195}
1196
1197async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1198 let room_id = RoomId::from_proto(message.room_id);
1199 {
1200 let room = session
1201 .db()
1202 .await
1203 .decline_call(Some(room_id), session.user_id)
1204 .await?
1205 .ok_or_else(|| anyhow!("failed to decline call"))?;
1206 room_updated(&room, &session.peer);
1207 }
1208
1209 for connection_id in session
1210 .connection_pool()
1211 .await
1212 .user_connection_ids(session.user_id)
1213 {
1214 session
1215 .peer
1216 .send(
1217 connection_id,
1218 proto::CallCanceled {
1219 room_id: room_id.to_proto(),
1220 },
1221 )
1222 .trace_err();
1223 }
1224 update_user_contacts(session.user_id, &session).await?;
1225 Ok(())
1226}
1227
1228async fn update_participant_location(
1229 request: proto::UpdateParticipantLocation,
1230 response: Response<proto::UpdateParticipantLocation>,
1231 session: Session,
1232) -> Result<()> {
1233 let room_id = RoomId::from_proto(request.room_id);
1234 let location = request
1235 .location
1236 .ok_or_else(|| anyhow!("invalid location"))?;
1237 let room = session
1238 .db()
1239 .await
1240 .update_room_participant_location(room_id, session.connection_id, location)
1241 .await?;
1242 room_updated(&room, &session.peer);
1243 response.send(proto::Ack {})?;
1244 Ok(())
1245}
1246
1247async fn share_project(
1248 request: proto::ShareProject,
1249 response: Response<proto::ShareProject>,
1250 session: Session,
1251) -> Result<()> {
1252 let (project_id, room) = &*session
1253 .db()
1254 .await
1255 .share_project(
1256 RoomId::from_proto(request.room_id),
1257 session.connection_id,
1258 &request.worktrees,
1259 )
1260 .await?;
1261 response.send(proto::ShareProjectResponse {
1262 project_id: project_id.to_proto(),
1263 })?;
1264 room_updated(&room, &session.peer);
1265
1266 Ok(())
1267}
1268
1269async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1270 let project_id = ProjectId::from_proto(message.project_id);
1271
1272 let (room, guest_connection_ids) = &*session
1273 .db()
1274 .await
1275 .unshare_project(project_id, session.connection_id)
1276 .await?;
1277
1278 broadcast(
1279 session.connection_id,
1280 guest_connection_ids.iter().copied(),
1281 |conn_id| session.peer.send(conn_id, message.clone()),
1282 );
1283 room_updated(&room, &session.peer);
1284
1285 Ok(())
1286}
1287
1288async fn join_project(
1289 request: proto::JoinProject,
1290 response: Response<proto::JoinProject>,
1291 session: Session,
1292) -> Result<()> {
1293 let project_id = ProjectId::from_proto(request.project_id);
1294 let guest_user_id = session.user_id;
1295
1296 tracing::info!(%project_id, "join project");
1297
1298 let (project, replica_id) = &mut *session
1299 .db()
1300 .await
1301 .join_project(project_id, session.connection_id)
1302 .await?;
1303
1304 let collaborators = project
1305 .collaborators
1306 .iter()
1307 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1308 .map(|collaborator| collaborator.to_proto())
1309 .collect::<Vec<_>>();
1310 let worktrees = project
1311 .worktrees
1312 .iter()
1313 .map(|(id, worktree)| proto::WorktreeMetadata {
1314 id: *id,
1315 root_name: worktree.root_name.clone(),
1316 visible: worktree.visible,
1317 abs_path: worktree.abs_path.clone(),
1318 })
1319 .collect::<Vec<_>>();
1320
1321 for collaborator in &collaborators {
1322 session
1323 .peer
1324 .send(
1325 collaborator.peer_id.unwrap().into(),
1326 proto::AddProjectCollaborator {
1327 project_id: project_id.to_proto(),
1328 collaborator: Some(proto::Collaborator {
1329 peer_id: Some(session.connection_id.into()),
1330 replica_id: replica_id.0 as u32,
1331 user_id: guest_user_id.to_proto(),
1332 }),
1333 },
1334 )
1335 .trace_err();
1336 }
1337
1338 // First, we send the metadata associated with each worktree.
1339 response.send(proto::JoinProjectResponse {
1340 worktrees: worktrees.clone(),
1341 replica_id: replica_id.0 as u32,
1342 collaborators: collaborators.clone(),
1343 language_servers: project.language_servers.clone(),
1344 })?;
1345
1346 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1347 #[cfg(any(test, feature = "test-support"))]
1348 const MAX_CHUNK_SIZE: usize = 2;
1349 #[cfg(not(any(test, feature = "test-support")))]
1350 const MAX_CHUNK_SIZE: usize = 256;
1351
1352 // Stream this worktree's entries.
1353 let message = proto::UpdateWorktree {
1354 project_id: project_id.to_proto(),
1355 worktree_id,
1356 abs_path: worktree.abs_path.clone(),
1357 root_name: worktree.root_name,
1358 updated_entries: worktree.entries,
1359 removed_entries: Default::default(),
1360 scan_id: worktree.scan_id,
1361 is_last_update: worktree.is_complete,
1362 };
1363 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1364 session.peer.send(session.connection_id, update.clone())?;
1365 }
1366
1367 // Stream this worktree's diagnostics.
1368 for summary in worktree.diagnostic_summaries {
1369 session.peer.send(
1370 session.connection_id,
1371 proto::UpdateDiagnosticSummary {
1372 project_id: project_id.to_proto(),
1373 worktree_id: worktree.id,
1374 summary: Some(summary),
1375 },
1376 )?;
1377 }
1378 }
1379
1380 for language_server in &project.language_servers {
1381 session.peer.send(
1382 session.connection_id,
1383 proto::UpdateLanguageServer {
1384 project_id: project_id.to_proto(),
1385 language_server_id: language_server.id,
1386 variant: Some(
1387 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1388 proto::LspDiskBasedDiagnosticsUpdated {},
1389 ),
1390 ),
1391 },
1392 )?;
1393 }
1394
1395 Ok(())
1396}
1397
1398async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1399 let sender_id = session.connection_id;
1400 let project_id = ProjectId::from_proto(request.project_id);
1401
1402 let project = session
1403 .db()
1404 .await
1405 .leave_project(project_id, sender_id)
1406 .await?;
1407 tracing::info!(
1408 %project_id,
1409 host_user_id = %project.host_user_id,
1410 host_connection_id = %project.host_connection_id,
1411 "leave project"
1412 );
1413 project_left(&project, &session);
1414
1415 Ok(())
1416}
1417
1418async fn update_project(
1419 request: proto::UpdateProject,
1420 response: Response<proto::UpdateProject>,
1421 session: Session,
1422) -> Result<()> {
1423 let project_id = ProjectId::from_proto(request.project_id);
1424 let (room, guest_connection_ids) = &*session
1425 .db()
1426 .await
1427 .update_project(project_id, session.connection_id, &request.worktrees)
1428 .await?;
1429 broadcast(
1430 session.connection_id,
1431 guest_connection_ids.iter().copied(),
1432 |connection_id| {
1433 session
1434 .peer
1435 .forward_send(session.connection_id, connection_id, request.clone())
1436 },
1437 );
1438 room_updated(&room, &session.peer);
1439 response.send(proto::Ack {})?;
1440
1441 Ok(())
1442}
1443
1444async fn update_worktree(
1445 request: proto::UpdateWorktree,
1446 response: Response<proto::UpdateWorktree>,
1447 session: Session,
1448) -> Result<()> {
1449 let guest_connection_ids = session
1450 .db()
1451 .await
1452 .update_worktree(&request, session.connection_id)
1453 .await?;
1454
1455 broadcast(
1456 session.connection_id,
1457 guest_connection_ids.iter().copied(),
1458 |connection_id| {
1459 session
1460 .peer
1461 .forward_send(session.connection_id, connection_id, request.clone())
1462 },
1463 );
1464 response.send(proto::Ack {})?;
1465 Ok(())
1466}
1467
1468async fn update_diagnostic_summary(
1469 message: proto::UpdateDiagnosticSummary,
1470 session: Session,
1471) -> Result<()> {
1472 let guest_connection_ids = session
1473 .db()
1474 .await
1475 .update_diagnostic_summary(&message, session.connection_id)
1476 .await?;
1477
1478 broadcast(
1479 session.connection_id,
1480 guest_connection_ids.iter().copied(),
1481 |connection_id| {
1482 session
1483 .peer
1484 .forward_send(session.connection_id, connection_id, message.clone())
1485 },
1486 );
1487
1488 Ok(())
1489}
1490
1491async fn start_language_server(
1492 request: proto::StartLanguageServer,
1493 session: Session,
1494) -> Result<()> {
1495 let guest_connection_ids = session
1496 .db()
1497 .await
1498 .start_language_server(&request, session.connection_id)
1499 .await?;
1500
1501 broadcast(
1502 session.connection_id,
1503 guest_connection_ids.iter().copied(),
1504 |connection_id| {
1505 session
1506 .peer
1507 .forward_send(session.connection_id, connection_id, request.clone())
1508 },
1509 );
1510 Ok(())
1511}
1512
1513async fn update_language_server(
1514 request: proto::UpdateLanguageServer,
1515 session: Session,
1516) -> Result<()> {
1517 let project_id = ProjectId::from_proto(request.project_id);
1518 let project_connection_ids = session
1519 .db()
1520 .await
1521 .project_connection_ids(project_id, session.connection_id)
1522 .await?;
1523 broadcast(
1524 session.connection_id,
1525 project_connection_ids.iter().copied(),
1526 |connection_id| {
1527 session
1528 .peer
1529 .forward_send(session.connection_id, connection_id, request.clone())
1530 },
1531 );
1532 Ok(())
1533}
1534
1535async fn forward_project_request<T>(
1536 request: T,
1537 response: Response<T>,
1538 session: Session,
1539) -> Result<()>
1540where
1541 T: EntityMessage + RequestMessage,
1542{
1543 let project_id = ProjectId::from_proto(request.remote_entity_id());
1544 let host_connection_id = {
1545 let collaborators = session
1546 .db()
1547 .await
1548 .project_collaborators(project_id, session.connection_id)
1549 .await?;
1550 collaborators
1551 .iter()
1552 .find(|collaborator| collaborator.is_host)
1553 .ok_or_else(|| anyhow!("host not found"))?
1554 .connection_id
1555 };
1556
1557 let payload = session
1558 .peer
1559 .forward_request(session.connection_id, host_connection_id, request)
1560 .await?;
1561
1562 response.send(payload)?;
1563 Ok(())
1564}
1565
1566async fn save_buffer(
1567 request: proto::SaveBuffer,
1568 response: Response<proto::SaveBuffer>,
1569 session: Session,
1570) -> Result<()> {
1571 let project_id = ProjectId::from_proto(request.project_id);
1572 let host_connection_id = {
1573 let collaborators = session
1574 .db()
1575 .await
1576 .project_collaborators(project_id, session.connection_id)
1577 .await?;
1578 collaborators
1579 .iter()
1580 .find(|collaborator| collaborator.is_host)
1581 .ok_or_else(|| anyhow!("host not found"))?
1582 .connection_id
1583 };
1584 let response_payload = session
1585 .peer
1586 .forward_request(session.connection_id, host_connection_id, request.clone())
1587 .await?;
1588
1589 let mut collaborators = session
1590 .db()
1591 .await
1592 .project_collaborators(project_id, session.connection_id)
1593 .await?;
1594 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1595 let project_connection_ids = collaborators
1596 .iter()
1597 .map(|collaborator| collaborator.connection_id);
1598 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1599 session
1600 .peer
1601 .forward_send(host_connection_id, conn_id, response_payload.clone())
1602 });
1603 response.send(response_payload)?;
1604 Ok(())
1605}
1606
1607async fn create_buffer_for_peer(
1608 request: proto::CreateBufferForPeer,
1609 session: Session,
1610) -> Result<()> {
1611 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1612 session
1613 .peer
1614 .forward_send(session.connection_id, peer_id.into(), request)?;
1615 Ok(())
1616}
1617
1618async fn update_buffer(
1619 request: proto::UpdateBuffer,
1620 response: Response<proto::UpdateBuffer>,
1621 session: Session,
1622) -> Result<()> {
1623 let project_id = ProjectId::from_proto(request.project_id);
1624 let project_connection_ids = session
1625 .db()
1626 .await
1627 .project_connection_ids(project_id, session.connection_id)
1628 .await?;
1629
1630 broadcast(
1631 session.connection_id,
1632 project_connection_ids.iter().copied(),
1633 |connection_id| {
1634 session
1635 .peer
1636 .forward_send(session.connection_id, connection_id, request.clone())
1637 },
1638 );
1639 response.send(proto::Ack {})?;
1640 Ok(())
1641}
1642
1643async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1644 let project_id = ProjectId::from_proto(request.project_id);
1645 let project_connection_ids = session
1646 .db()
1647 .await
1648 .project_connection_ids(project_id, session.connection_id)
1649 .await?;
1650
1651 broadcast(
1652 session.connection_id,
1653 project_connection_ids.iter().copied(),
1654 |connection_id| {
1655 session
1656 .peer
1657 .forward_send(session.connection_id, connection_id, request.clone())
1658 },
1659 );
1660 Ok(())
1661}
1662
1663async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1664 let project_id = ProjectId::from_proto(request.project_id);
1665 let project_connection_ids = session
1666 .db()
1667 .await
1668 .project_connection_ids(project_id, session.connection_id)
1669 .await?;
1670 broadcast(
1671 session.connection_id,
1672 project_connection_ids.iter().copied(),
1673 |connection_id| {
1674 session
1675 .peer
1676 .forward_send(session.connection_id, connection_id, request.clone())
1677 },
1678 );
1679 Ok(())
1680}
1681
1682async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1683 let project_id = ProjectId::from_proto(request.project_id);
1684 let project_connection_ids = session
1685 .db()
1686 .await
1687 .project_connection_ids(project_id, session.connection_id)
1688 .await?;
1689 broadcast(
1690 session.connection_id,
1691 project_connection_ids.iter().copied(),
1692 |connection_id| {
1693 session
1694 .peer
1695 .forward_send(session.connection_id, connection_id, request.clone())
1696 },
1697 );
1698 Ok(())
1699}
1700
1701async fn follow(
1702 request: proto::Follow,
1703 response: Response<proto::Follow>,
1704 session: Session,
1705) -> Result<()> {
1706 let project_id = ProjectId::from_proto(request.project_id);
1707 let leader_id = request
1708 .leader_id
1709 .ok_or_else(|| anyhow!("invalid leader id"))?
1710 .into();
1711 let follower_id = session.connection_id;
1712 {
1713 let project_connection_ids = session
1714 .db()
1715 .await
1716 .project_connection_ids(project_id, session.connection_id)
1717 .await?;
1718
1719 if !project_connection_ids.contains(&leader_id) {
1720 Err(anyhow!("no such peer"))?;
1721 }
1722 }
1723
1724 let mut response_payload = session
1725 .peer
1726 .forward_request(session.connection_id, leader_id, request)
1727 .await?;
1728 response_payload
1729 .views
1730 .retain(|view| view.leader_id != Some(follower_id.into()));
1731 response.send(response_payload)?;
1732 Ok(())
1733}
1734
1735async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1736 let project_id = ProjectId::from_proto(request.project_id);
1737 let leader_id = request
1738 .leader_id
1739 .ok_or_else(|| anyhow!("invalid leader id"))?
1740 .into();
1741 let project_connection_ids = session
1742 .db()
1743 .await
1744 .project_connection_ids(project_id, session.connection_id)
1745 .await?;
1746 if !project_connection_ids.contains(&leader_id) {
1747 Err(anyhow!("no such peer"))?;
1748 }
1749 session
1750 .peer
1751 .forward_send(session.connection_id, leader_id, request)?;
1752 Ok(())
1753}
1754
1755async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1756 let project_id = ProjectId::from_proto(request.project_id);
1757 let project_connection_ids = session
1758 .db
1759 .lock()
1760 .await
1761 .project_connection_ids(project_id, session.connection_id)
1762 .await?;
1763
1764 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1765 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1766 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1767 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1768 });
1769 for follower_peer_id in request.follower_ids.iter().copied() {
1770 let follower_connection_id = follower_peer_id.into();
1771 if project_connection_ids.contains(&follower_connection_id)
1772 && Some(follower_peer_id) != leader_id
1773 {
1774 session.peer.forward_send(
1775 session.connection_id,
1776 follower_connection_id,
1777 request.clone(),
1778 )?;
1779 }
1780 }
1781 Ok(())
1782}
1783
1784async fn get_users(
1785 request: proto::GetUsers,
1786 response: Response<proto::GetUsers>,
1787 session: Session,
1788) -> Result<()> {
1789 let user_ids = request
1790 .user_ids
1791 .into_iter()
1792 .map(UserId::from_proto)
1793 .collect();
1794 let users = session
1795 .db()
1796 .await
1797 .get_users_by_ids(user_ids)
1798 .await?
1799 .into_iter()
1800 .map(|user| proto::User {
1801 id: user.id.to_proto(),
1802 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1803 github_login: user.github_login,
1804 })
1805 .collect();
1806 response.send(proto::UsersResponse { users })?;
1807 Ok(())
1808}
1809
1810async fn fuzzy_search_users(
1811 request: proto::FuzzySearchUsers,
1812 response: Response<proto::FuzzySearchUsers>,
1813 session: Session,
1814) -> Result<()> {
1815 let query = request.query;
1816 let users = match query.len() {
1817 0 => vec![],
1818 1 | 2 => session
1819 .db()
1820 .await
1821 .get_user_by_github_account(&query, None)
1822 .await?
1823 .into_iter()
1824 .collect(),
1825 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1826 };
1827 let users = users
1828 .into_iter()
1829 .filter(|user| user.id != session.user_id)
1830 .map(|user| proto::User {
1831 id: user.id.to_proto(),
1832 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1833 github_login: user.github_login,
1834 })
1835 .collect();
1836 response.send(proto::UsersResponse { users })?;
1837 Ok(())
1838}
1839
1840async fn request_contact(
1841 request: proto::RequestContact,
1842 response: Response<proto::RequestContact>,
1843 session: Session,
1844) -> Result<()> {
1845 let requester_id = session.user_id;
1846 let responder_id = UserId::from_proto(request.responder_id);
1847 if requester_id == responder_id {
1848 return Err(anyhow!("cannot add yourself as a contact"))?;
1849 }
1850
1851 session
1852 .db()
1853 .await
1854 .send_contact_request(requester_id, responder_id)
1855 .await?;
1856
1857 // Update outgoing contact requests of requester
1858 let mut update = proto::UpdateContacts::default();
1859 update.outgoing_requests.push(responder_id.to_proto());
1860 for connection_id in session
1861 .connection_pool()
1862 .await
1863 .user_connection_ids(requester_id)
1864 {
1865 session.peer.send(connection_id, update.clone())?;
1866 }
1867
1868 // Update incoming contact requests of responder
1869 let mut update = proto::UpdateContacts::default();
1870 update
1871 .incoming_requests
1872 .push(proto::IncomingContactRequest {
1873 requester_id: requester_id.to_proto(),
1874 should_notify: true,
1875 });
1876 for connection_id in session
1877 .connection_pool()
1878 .await
1879 .user_connection_ids(responder_id)
1880 {
1881 session.peer.send(connection_id, update.clone())?;
1882 }
1883
1884 response.send(proto::Ack {})?;
1885 Ok(())
1886}
1887
1888async fn respond_to_contact_request(
1889 request: proto::RespondToContactRequest,
1890 response: Response<proto::RespondToContactRequest>,
1891 session: Session,
1892) -> Result<()> {
1893 let responder_id = session.user_id;
1894 let requester_id = UserId::from_proto(request.requester_id);
1895 let db = session.db().await;
1896 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1897 db.dismiss_contact_notification(responder_id, requester_id)
1898 .await?;
1899 } else {
1900 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1901
1902 db.respond_to_contact_request(responder_id, requester_id, accept)
1903 .await?;
1904 let requester_busy = db.is_user_busy(requester_id).await?;
1905 let responder_busy = db.is_user_busy(responder_id).await?;
1906
1907 let pool = session.connection_pool().await;
1908 // Update responder with new contact
1909 let mut update = proto::UpdateContacts::default();
1910 if accept {
1911 update
1912 .contacts
1913 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1914 }
1915 update
1916 .remove_incoming_requests
1917 .push(requester_id.to_proto());
1918 for connection_id in pool.user_connection_ids(responder_id) {
1919 session.peer.send(connection_id, update.clone())?;
1920 }
1921
1922 // Update requester with new contact
1923 let mut update = proto::UpdateContacts::default();
1924 if accept {
1925 update
1926 .contacts
1927 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1928 }
1929 update
1930 .remove_outgoing_requests
1931 .push(responder_id.to_proto());
1932 for connection_id in pool.user_connection_ids(requester_id) {
1933 session.peer.send(connection_id, update.clone())?;
1934 }
1935 }
1936
1937 response.send(proto::Ack {})?;
1938 Ok(())
1939}
1940
1941async fn remove_contact(
1942 request: proto::RemoveContact,
1943 response: Response<proto::RemoveContact>,
1944 session: Session,
1945) -> Result<()> {
1946 let requester_id = session.user_id;
1947 let responder_id = UserId::from_proto(request.user_id);
1948 let db = session.db().await;
1949 db.remove_contact(requester_id, responder_id).await?;
1950
1951 let pool = session.connection_pool().await;
1952 // Update outgoing contact requests of requester
1953 let mut update = proto::UpdateContacts::default();
1954 update
1955 .remove_outgoing_requests
1956 .push(responder_id.to_proto());
1957 for connection_id in pool.user_connection_ids(requester_id) {
1958 session.peer.send(connection_id, update.clone())?;
1959 }
1960
1961 // Update incoming contact requests of responder
1962 let mut update = proto::UpdateContacts::default();
1963 update
1964 .remove_incoming_requests
1965 .push(requester_id.to_proto());
1966 for connection_id in pool.user_connection_ids(responder_id) {
1967 session.peer.send(connection_id, update.clone())?;
1968 }
1969
1970 response.send(proto::Ack {})?;
1971 Ok(())
1972}
1973
1974async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1975 let project_id = ProjectId::from_proto(request.project_id);
1976 let project_connection_ids = session
1977 .db()
1978 .await
1979 .project_connection_ids(project_id, session.connection_id)
1980 .await?;
1981 broadcast(
1982 session.connection_id,
1983 project_connection_ids.iter().copied(),
1984 |connection_id| {
1985 session
1986 .peer
1987 .forward_send(session.connection_id, connection_id, request.clone())
1988 },
1989 );
1990 Ok(())
1991}
1992
1993async fn get_private_user_info(
1994 _request: proto::GetPrivateUserInfo,
1995 response: Response<proto::GetPrivateUserInfo>,
1996 session: Session,
1997) -> Result<()> {
1998 let metrics_id = session
1999 .db()
2000 .await
2001 .get_user_metrics_id(session.user_id)
2002 .await?;
2003 let user = session
2004 .db()
2005 .await
2006 .get_user_by_id(session.user_id)
2007 .await?
2008 .ok_or_else(|| anyhow!("user not found"))?;
2009 response.send(proto::GetPrivateUserInfoResponse {
2010 metrics_id,
2011 staff: user.admin,
2012 })?;
2013 Ok(())
2014}
2015
2016fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2017 match message {
2018 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2019 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2020 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2021 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2022 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2023 code: frame.code.into(),
2024 reason: frame.reason,
2025 })),
2026 }
2027}
2028
2029fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2030 match message {
2031 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2032 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2033 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2034 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2035 AxumMessage::Close(frame) => {
2036 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2037 code: frame.code.into(),
2038 reason: frame.reason,
2039 }))
2040 }
2041 }
2042}
2043
2044fn build_initial_contacts_update(
2045 contacts: Vec<db::Contact>,
2046 pool: &ConnectionPool,
2047) -> proto::UpdateContacts {
2048 let mut update = proto::UpdateContacts::default();
2049
2050 for contact in contacts {
2051 match contact {
2052 db::Contact::Accepted {
2053 user_id,
2054 should_notify,
2055 busy,
2056 } => {
2057 update
2058 .contacts
2059 .push(contact_for_user(user_id, should_notify, busy, &pool));
2060 }
2061 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2062 db::Contact::Incoming {
2063 user_id,
2064 should_notify,
2065 } => update
2066 .incoming_requests
2067 .push(proto::IncomingContactRequest {
2068 requester_id: user_id.to_proto(),
2069 should_notify,
2070 }),
2071 }
2072 }
2073
2074 update
2075}
2076
2077fn contact_for_user(
2078 user_id: UserId,
2079 should_notify: bool,
2080 busy: bool,
2081 pool: &ConnectionPool,
2082) -> proto::Contact {
2083 proto::Contact {
2084 user_id: user_id.to_proto(),
2085 online: pool.is_user_online(user_id),
2086 busy,
2087 should_notify,
2088 }
2089}
2090
2091fn room_updated(room: &proto::Room, peer: &Peer) {
2092 for participant in &room.participants {
2093 if let Some(peer_id) = participant
2094 .peer_id
2095 .ok_or_else(|| anyhow!("invalid participant peer id"))
2096 .trace_err()
2097 {
2098 peer.send(
2099 peer_id.into(),
2100 proto::RoomUpdated {
2101 room: Some(room.clone()),
2102 },
2103 )
2104 .trace_err();
2105 }
2106 }
2107}
2108
2109async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2110 let db = session.db().await;
2111 let contacts = db.get_contacts(user_id).await?;
2112 let busy = db.is_user_busy(user_id).await?;
2113
2114 let pool = session.connection_pool().await;
2115 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2116 for contact in contacts {
2117 if let db::Contact::Accepted {
2118 user_id: contact_user_id,
2119 ..
2120 } = contact
2121 {
2122 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2123 session
2124 .peer
2125 .send(
2126 contact_conn_id,
2127 proto::UpdateContacts {
2128 contacts: vec![updated_contact.clone()],
2129 remove_contacts: Default::default(),
2130 incoming_requests: Default::default(),
2131 remove_incoming_requests: Default::default(),
2132 outgoing_requests: Default::default(),
2133 remove_outgoing_requests: Default::default(),
2134 },
2135 )
2136 .trace_err();
2137 }
2138 }
2139 }
2140 Ok(())
2141}
2142
/// Removes the session's connection from whatever room it is in, if any.
///
/// When a room was left, this:
/// - notifies the connections of every project the session left (see
///   `project_left`) and broadcasts the updated room to its participants;
/// - sends `CallCanceled` to every connection of each user whose pending call
///   was canceled (`canceled_calls_to_user_ids`), logging failed sends;
/// - refreshes contact status for the session's user and those callees;
/// - if a LiveKit client is configured, removes the user from the LiveKit
///   room, and deletes that room when no participants remain.
///
/// Returns `Ok(())` without doing anything when the connection was not in a
/// room. Errors from the database and from contact refreshes are propagated;
/// LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared uninitialized: the `if let` branch below assigns all of them,
    // and the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // Everything needed later is copied or moved out of `left_room`. Then it,
    // and the guard returned by `session.db()` in the scrutinee, are dropped
    // when this `if let` ends. That matters because `update_user_contacts`
    // below acquires the database again.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // The room is reported after this connection left, so no remaining
        // participants means the LiveKit room can be deleted below.
        delete_live_kit_room = left_room.room.participants.is_empty();
    } else {
        return Ok(());
    }

    // Scope the connection-pool guard to the cancellation notices.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` may still be needed for `delete_room`.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2201
2202fn project_left(project: &db::LeftProject, session: &Session) {
2203 for connection_id in &project.connection_ids {
2204 if project.host_user_id == session.user_id {
2205 session
2206 .peer
2207 .send(
2208 *connection_id,
2209 proto::UnshareProject {
2210 project_id: project.id.to_proto(),
2211 },
2212 )
2213 .trace_err();
2214 } else {
2215 session
2216 .peer
2217 .send(
2218 *connection_id,
2219 proto::RemoveProjectCollaborator {
2220 project_id: project.id.to_proto(),
2221 peer_id: Some(session.connection_id.into()),
2222 },
2223 )
2224 .trace_err();
2225 }
2226 }
2227}
2228
2229pub trait ResultExt {
2230 type Ok;
2231
2232 fn trace_err(self) -> Option<Self::Ok>;
2233}
2234
2235impl<T, E> ResultExt for Result<T, E>
2236where
2237 E: std::fmt::Debug,
2238{
2239 type Ok = T;
2240
2241 fn trace_err(self) -> Option<T> {
2242 match self {
2243 Ok(value) => Some(value),
2244 Err(error) => {
2245 tracing::error!("{:?}", error);
2246 None
2247 }
2248 }
2249 }
2250}