1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `connection_lost` waits after a connection drops before removing
/// the session from its room and refreshing the user's contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long a freshly started server waits before cleaning up stale rooms.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
156#[derive(Serialize)]
157pub struct ServerSnapshot<'a> {
158 peer: &'a Peer,
159 #[serde(serialize_with = "serialize_deref")]
160 connection_pool: ConnectionPoolGuard<'a>,
161}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_message_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(save_buffer)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 pub async fn start(&self) -> Result<()> {
247 let server_id = *self.id.lock();
248 let app_state = self.app_state.clone();
249 let peer = self.peer.clone();
250 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
251 let pool = self.connection_pool.clone();
252 let live_kit_client = self.app_state.live_kit_client.clone();
253
254 let span = info_span!("start server");
255 self.executor.spawn_detached(
256 async move {
257 tracing::info!("waiting for cleanup timeout");
258 timeout.await;
259 tracing::info!("cleanup timeout expired, retrieving stale rooms");
260 if let Some(room_ids) = app_state
261 .db
262 .stale_room_ids(&app_state.config.zed_environment, server_id)
263 .await
264 .trace_err()
265 {
266 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
267 for room_id in room_ids {
268 let mut contacts_to_update = HashSet::default();
269 let mut canceled_calls_to_user_ids = Vec::new();
270 let mut live_kit_room = String::new();
271 let mut delete_live_kit_room = false;
272
273 if let Ok(mut refreshed_room) =
274 app_state.db.refresh_room(room_id, server_id).await
275 {
276 tracing::info!(
277 room_id = room_id.0,
278 new_participant_count = refreshed_room.room.participants.len(),
279 "refreshed room"
280 );
281 room_updated(&refreshed_room.room, &peer);
282 contacts_to_update
283 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
284 contacts_to_update
285 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
286 canceled_calls_to_user_ids =
287 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
288 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
289 delete_live_kit_room = refreshed_room.room.participants.is_empty();
290 }
291
292 {
293 let pool = pool.lock();
294 for canceled_user_id in canceled_calls_to_user_ids {
295 for connection_id in pool.user_connection_ids(canceled_user_id) {
296 peer.send(
297 connection_id,
298 proto::CallCanceled {
299 room_id: room_id.to_proto(),
300 },
301 )
302 .trace_err();
303 }
304 }
305 }
306
307 for user_id in contacts_to_update {
308 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
309 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
310 if let Some((busy, contacts)) = busy.zip(contacts) {
311 let pool = pool.lock();
312 let updated_contact = contact_for_user(user_id, false, busy, &pool);
313 for contact in contacts {
314 if let db::Contact::Accepted {
315 user_id: contact_user_id,
316 ..
317 } = contact
318 {
319 for contact_conn_id in
320 pool.user_connection_ids(contact_user_id)
321 {
322 peer.send(
323 contact_conn_id,
324 proto::UpdateContacts {
325 contacts: vec![updated_contact.clone()],
326 remove_contacts: Default::default(),
327 incoming_requests: Default::default(),
328 remove_incoming_requests: Default::default(),
329 outgoing_requests: Default::default(),
330 remove_outgoing_requests: Default::default(),
331 },
332 )
333 .trace_err();
334 }
335 }
336 }
337 }
338 }
339
340 if let Some(live_kit) = live_kit_client.as_ref() {
341 if delete_live_kit_room {
342 live_kit.delete_room(live_kit_room).await.trace_err();
343 }
344 }
345 }
346 }
347
348 app_state
349 .db
350 .delete_stale_servers(&app_state.config.zed_environment, server_id)
351 .await
352 .trace_err();
353 }
354 .instrument(span),
355 );
356 Ok(())
357 }
358
359 pub fn teardown(&self) {
360 self.peer.teardown();
361 self.connection_pool.lock().reset();
362 let _ = self.teardown.send(());
363 }
364
365 #[cfg(test)]
366 pub fn reset(&self, id: ServerId) {
367 self.teardown();
368 *self.id.lock() = id;
369 self.peer.reset(id.0 as u32);
370 }
371
372 #[cfg(test)]
373 pub fn id(&self) -> ServerId {
374 *self.id.lock()
375 }
376
377 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
378 where
379 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
380 Fut: 'static + Send + Future<Output = Result<()>>,
381 M: EnvelopedMessage,
382 {
383 let prev_handler = self.handlers.insert(
384 TypeId::of::<M>(),
385 Box::new(move |envelope, session| {
386 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
387 let span = info_span!(
388 "handle message",
389 payload_type = envelope.payload_type_name()
390 );
391 span.in_scope(|| {
392 tracing::info!(
393 payload_type = envelope.payload_type_name(),
394 "message received"
395 );
396 });
397 let future = (handler)(*envelope, session);
398 async move {
399 if let Err(error) = future.await {
400 tracing::error!(%error, "error handling message");
401 }
402 }
403 .instrument(span)
404 .boxed()
405 }),
406 );
407 if prev_handler.is_some() {
408 panic!("registered a handler for the same message twice");
409 }
410 self
411 }
412
413 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
414 where
415 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
416 Fut: 'static + Send + Future<Output = Result<()>>,
417 M: EnvelopedMessage,
418 {
419 self.add_handler(move |envelope, session| handler(envelope.payload, session));
420 self
421 }
422
423 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
424 where
425 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
426 Fut: Send + Future<Output = Result<()>>,
427 M: RequestMessage,
428 {
429 let handler = Arc::new(handler);
430 self.add_handler(move |envelope, session| {
431 let receipt = envelope.receipt();
432 let handler = handler.clone();
433 async move {
434 let peer = session.peer.clone();
435 let responded = Arc::new(AtomicBool::default());
436 let response = Response {
437 peer: peer.clone(),
438 responded: responded.clone(),
439 receipt,
440 };
441 match (handler)(envelope.payload, response, session).await {
442 Ok(()) => {
443 if responded.load(std::sync::atomic::Ordering::SeqCst) {
444 Ok(())
445 } else {
446 Err(anyhow!("handler did not send a response"))?
447 }
448 }
449 Err(error) => {
450 peer.respond_with_error(
451 receipt,
452 proto::Error {
453 message: error.to_string(),
454 },
455 )?;
456 Err(error)
457 }
458 }
459 }
460 })
461 }
462
    /// Drives one authenticated client connection from handshake to sign-out.
    ///
    /// Setup, in order:
    /// - registers the connection with the peer and sends `Hello`;
    /// - reports the new `ConnectionId` through `send_connection_id` when given;
    /// - sends `ShowContacts` on the user's first-ever connection;
    /// - adds the connection to the pool;
    /// - pushes the initial contacts, invite info and any pending incoming call.
    ///
    /// It then dispatches incoming messages to the registered handlers until
    /// one of three things happens:
    /// - the I/O future finishes;
    /// - the incoming stream ends;
    /// - the server is torn down.
    ///
    /// In the first two cases `connection_lost` runs before returning. The
    /// teardown arm returns `Ok(())` immediately without running it.
    ///
    /// NOTE(review): errors propagated with `?` during setup (after
    /// `peer.add_connection`, and later after `pool.add_connection`) return
    /// without calling `connection_lost`, so the connection may remain
    /// registered in the peer/pool. Confirm whether it is cleaned up elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribed before the returned future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Hand the connection id back to the caller, if it asked for it.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // The pool lock is scoped so it is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                executor: executor.clone(),
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // Biased: arms are polled in the order written, teardown first.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                // Background messages run detached; foreground ones are
                                // driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
590
591 pub async fn invite_code_redeemed(
592 self: &Arc<Self>,
593 inviter_id: UserId,
594 invitee_id: UserId,
595 ) -> Result<()> {
596 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
597 if let Some(code) = &user.invite_code {
598 let pool = self.connection_pool.lock();
599 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
600 for connection_id in pool.user_connection_ids(inviter_id) {
601 self.peer.send(
602 connection_id,
603 proto::UpdateContacts {
604 contacts: vec![invitee_contact.clone()],
605 ..Default::default()
606 },
607 )?;
608 self.peer.send(
609 connection_id,
610 proto::UpdateInviteInfo {
611 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
612 count: user.invite_count as u32,
613 },
614 )?;
615 }
616 }
617 }
618 Ok(())
619 }
620
621 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
622 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
623 if let Some(invite_code) = &user.invite_code {
624 let pool = self.connection_pool.lock();
625 for connection_id in pool.user_connection_ids(user_id) {
626 self.peer.send(
627 connection_id,
628 proto::UpdateInviteInfo {
629 url: format!(
630 "{}{}",
631 self.app_state.config.invite_link_prefix, invite_code
632 ),
633 count: user.invite_count as u32,
634 },
635 )?;
636 }
637 }
638 }
639 Ok(())
640 }
641
642 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
643 ServerSnapshot {
644 connection_pool: ConnectionPoolGuard {
645 guard: self.connection_pool.lock(),
646 _not_send: PhantomData,
647 },
648 peer: &self.peer,
649 }
650 }
651}
652
653impl<'a> Deref for ConnectionPoolGuard<'a> {
654 type Target = ConnectionPool;
655
656 fn deref(&self) -> &Self::Target {
657 &*self.guard
658 }
659}
660
661impl<'a> DerefMut for ConnectionPoolGuard<'a> {
662 fn deref_mut(&mut self) -> &mut Self::Target {
663 &mut *self.guard
664 }
665}
666
667impl<'a> Drop for ConnectionPoolGuard<'a> {
668 fn drop(&mut self) {
669 #[cfg(test)]
670 self.check_invariants();
671 }
672}
673
674fn broadcast<F>(
675 sender_id: Option<ConnectionId>,
676 receiver_ids: impl IntoIterator<Item = ConnectionId>,
677 mut f: F,
678) where
679 F: FnMut(ConnectionId) -> anyhow::Result<()>,
680{
681 for receiver_id in receiver_ids {
682 if Some(receiver_id) != sender_id {
683 if let Err(error) = f(receiver_id) {
684 tracing::error!("failed to send to {:?} {}", receiver_id, error);
685 }
686 }
687 }
688}
689
690lazy_static! {
691 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
692}
693
694pub struct ProtocolVersion(u32);
695
696impl Header for ProtocolVersion {
697 fn name() -> &'static HeaderName {
698 &ZED_PROTOCOL_VERSION
699 }
700
701 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
702 where
703 Self: Sized,
704 I: Iterator<Item = &'i axum::http::HeaderValue>,
705 {
706 let version = values
707 .next()
708 .ok_or_else(axum::headers::Error::invalid)?
709 .to_str()
710 .map_err(|_| axum::headers::Error::invalid())?
711 .parse()
712 .map_err(|_| axum::headers::Error::invalid())?;
713 Ok(Self(version))
714 }
715
716 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
717 values.extend([self.0.to_string().parse().unwrap()]);
718 }
719}
720
721pub fn routes(server: Arc<Server>) -> Router<Body> {
722 Router::new()
723 .route("/rpc", get(handle_websocket_request))
724 .layer(
725 ServiceBuilder::new()
726 .layer(Extension(server.app_state.clone()))
727 .layer(middleware::from_fn(auth::validate_header)),
728 )
729 .route("/metrics", get(handle_metrics))
730 .layer(Extension(server))
731}
732
733pub async fn handle_websocket_request(
734 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
735 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
736 Extension(server): Extension<Arc<Server>>,
737 Extension(user): Extension<User>,
738 ws: WebSocketUpgrade,
739) -> axum::response::Response {
740 if protocol_version != rpc::PROTOCOL_VERSION {
741 return (
742 StatusCode::UPGRADE_REQUIRED,
743 "client must be upgraded".to_string(),
744 )
745 .into_response();
746 }
747 let socket_address = socket_address.to_string();
748 ws.on_upgrade(move |socket| {
749 use util::ResultExt;
750 let socket = socket
751 .map_ok(to_tungstenite_message)
752 .err_into()
753 .with(|message| async move { Ok(to_axum_message(message)) });
754 let connection = Connection::new(Box::pin(socket));
755 async move {
756 server
757 .handle_connection(connection, socket_address, user, None, Executor::Production)
758 .await
759 .log_err();
760 }
761 })
762}
763
764pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
765 let connections = server
766 .connection_pool
767 .lock()
768 .connections()
769 .filter(|connection| !connection.admin)
770 .count();
771
772 METRIC_CONNECTIONS.set(connections as _);
773
774 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
775 METRIC_SHARED_PROJECTS.set(shared_projects as _);
776
777 let encoder = prometheus::TextEncoder::new();
778 let metric_families = prometheus::gather();
779 let encoded_metrics = encoder
780 .encode_to_string(&metric_families)
781 .map_err(|err| anyhow!("{}", err))?;
782 Ok(encoded_metrics)
783}
784
/// Cleans up after a client connection has ended.
///
/// Immediately, it does three things:
/// - disconnects the connection from the peer;
/// - removes it from the pool (a failure here is returned as an error);
/// - reports the loss to the database (errors are only logged).
///
/// It then waits [`RECONNECT_TIMEOUT`], presumably to give the client time to
/// reconnect. If the server is torn down first, nothing else happens.
/// Otherwise, once the timeout fires:
/// - the session leaves its room via `leave_room_for_session`;
/// - if the user has no connection online, `decline_call(None, ..)` is issued;
/// - the user's contacts are refreshed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary of the condition and is released
            // before the body runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
825
826async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
827 response.send(proto::Ack {})?;
828 Ok(())
829}
830
831async fn create_room(
832 _request: proto::CreateRoom,
833 response: Response<proto::CreateRoom>,
834 session: Session,
835) -> Result<()> {
836 let live_kit_room = nanoid::nanoid!(30);
837 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
838 if let Some(_) = live_kit
839 .create_room(live_kit_room.clone())
840 .await
841 .trace_err()
842 {
843 if let Some(token) = live_kit
844 .room_token(&live_kit_room, &session.user_id.to_string())
845 .trace_err()
846 {
847 Some(proto::LiveKitConnectionInfo {
848 server_url: live_kit.url().into(),
849 token,
850 })
851 } else {
852 None
853 }
854 } else {
855 None
856 }
857 } else {
858 None
859 };
860
861 {
862 let room = session
863 .db()
864 .await
865 .create_room(session.user_id, session.connection_id, &live_kit_room)
866 .await?;
867
868 response.send(proto::CreateRoomResponse {
869 room: Some(room.clone()),
870 live_kit_connection_info,
871 })?;
872 }
873
874 update_user_contacts(session.user_id, &session).await?;
875 Ok(())
876}
877
878async fn join_room(
879 request: proto::JoinRoom,
880 response: Response<proto::JoinRoom>,
881 session: Session,
882) -> Result<()> {
883 let room_id = RoomId::from_proto(request.id);
884 let room = {
885 let room = session
886 .db()
887 .await
888 .join_room(room_id, session.user_id, session.connection_id)
889 .await?;
890 room_updated(&room, &session.peer);
891 room.clone()
892 };
893
894 for connection_id in session
895 .connection_pool()
896 .await
897 .user_connection_ids(session.user_id)
898 {
899 session
900 .peer
901 .send(
902 connection_id,
903 proto::CallCanceled {
904 room_id: room_id.to_proto(),
905 },
906 )
907 .trace_err();
908 }
909
910 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
911 if let Some(token) = live_kit
912 .room_token(&room.live_kit_room, &session.user_id.to_string())
913 .trace_err()
914 {
915 Some(proto::LiveKitConnectionInfo {
916 server_url: live_kit.url().into(),
917 token,
918 })
919 } else {
920 None
921 }
922 } else {
923 None
924 };
925
926 response.send(proto::JoinRoomResponse {
927 room: Some(room),
928 live_kit_connection_info,
929 })?;
930
931 update_user_contacts(session.user_id, &session).await?;
932 Ok(())
933}
934
935async fn rejoin_room(
936 request: proto::RejoinRoom,
937 response: Response<proto::RejoinRoom>,
938 session: Session,
939) -> Result<()> {
940 {
941 let mut rejoined_room = session
942 .db()
943 .await
944 .rejoin_room(request, session.user_id, session.connection_id)
945 .await?;
946
947 response.send(proto::RejoinRoomResponse {
948 room: Some(rejoined_room.room.clone()),
949 reshared_projects: rejoined_room
950 .reshared_projects
951 .iter()
952 .map(|project| proto::ResharedProject {
953 id: project.id.to_proto(),
954 collaborators: project
955 .collaborators
956 .iter()
957 .map(|collaborator| collaborator.to_proto())
958 .collect(),
959 })
960 .collect(),
961 rejoined_projects: rejoined_room
962 .rejoined_projects
963 .iter()
964 .map(|rejoined_project| proto::RejoinedProject {
965 id: rejoined_project.id.to_proto(),
966 worktrees: rejoined_project
967 .worktrees
968 .iter()
969 .map(|worktree| proto::WorktreeMetadata {
970 id: worktree.id,
971 root_name: worktree.root_name.clone(),
972 visible: worktree.visible,
973 abs_path: worktree.abs_path.clone(),
974 })
975 .collect(),
976 collaborators: rejoined_project
977 .collaborators
978 .iter()
979 .map(|collaborator| collaborator.to_proto())
980 .collect(),
981 language_servers: rejoined_project.language_servers.clone(),
982 })
983 .collect(),
984 })?;
985 room_updated(&rejoined_room.room, &session.peer);
986
987 for project in &rejoined_room.reshared_projects {
988 for collaborator in &project.collaborators {
989 session
990 .peer
991 .send(
992 collaborator.connection_id,
993 proto::UpdateProjectCollaborator {
994 project_id: project.id.to_proto(),
995 old_peer_id: Some(project.old_connection_id.into()),
996 new_peer_id: Some(session.connection_id.into()),
997 },
998 )
999 .trace_err();
1000 }
1001
1002 broadcast(
1003 Some(session.connection_id),
1004 project
1005 .collaborators
1006 .iter()
1007 .map(|collaborator| collaborator.connection_id),
1008 |connection_id| {
1009 session.peer.forward_send(
1010 session.connection_id,
1011 connection_id,
1012 proto::UpdateProject {
1013 project_id: project.id.to_proto(),
1014 worktrees: project.worktrees.clone(),
1015 },
1016 )
1017 },
1018 );
1019 }
1020
1021 for project in &rejoined_room.rejoined_projects {
1022 for collaborator in &project.collaborators {
1023 session
1024 .peer
1025 .send(
1026 collaborator.connection_id,
1027 proto::UpdateProjectCollaborator {
1028 project_id: project.id.to_proto(),
1029 old_peer_id: Some(project.old_connection_id.into()),
1030 new_peer_id: Some(session.connection_id.into()),
1031 },
1032 )
1033 .trace_err();
1034 }
1035 }
1036
1037 for project in &mut rejoined_room.rejoined_projects {
1038 for worktree in mem::take(&mut project.worktrees) {
1039 #[cfg(any(test, feature = "test-support"))]
1040 const MAX_CHUNK_SIZE: usize = 2;
1041 #[cfg(not(any(test, feature = "test-support")))]
1042 const MAX_CHUNK_SIZE: usize = 256;
1043
1044 // Stream this worktree's entries.
1045 let message = proto::UpdateWorktree {
1046 project_id: project.id.to_proto(),
1047 worktree_id: worktree.id,
1048 abs_path: worktree.abs_path.clone(),
1049 root_name: worktree.root_name,
1050 updated_entries: worktree.updated_entries,
1051 removed_entries: worktree.removed_entries,
1052 scan_id: worktree.scan_id,
1053 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1054 };
1055 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1056 session.peer.send(session.connection_id, update.clone())?;
1057 }
1058
1059 // Stream this worktree's diagnostics.
1060 for summary in worktree.diagnostic_summaries {
1061 session.peer.send(
1062 session.connection_id,
1063 proto::UpdateDiagnosticSummary {
1064 project_id: project.id.to_proto(),
1065 worktree_id: worktree.id,
1066 summary: Some(summary),
1067 },
1068 )?;
1069 }
1070 }
1071
1072 for language_server in &project.language_servers {
1073 session.peer.send(
1074 session.connection_id,
1075 proto::UpdateLanguageServer {
1076 project_id: project.id.to_proto(),
1077 language_server_id: language_server.id,
1078 variant: Some(
1079 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1080 proto::LspDiskBasedDiagnosticsUpdated {},
1081 ),
1082 ),
1083 },
1084 )?;
1085 }
1086 }
1087 }
1088
1089 update_user_contacts(session.user_id, &session).await?;
1090 Ok(())
1091}
1092
1093async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
1094 leave_room_for_session(&session).await
1095}
1096
1097async fn call(
1098 request: proto::Call,
1099 response: Response<proto::Call>,
1100 session: Session,
1101) -> Result<()> {
1102 let room_id = RoomId::from_proto(request.room_id);
1103 let calling_user_id = session.user_id;
1104 let calling_connection_id = session.connection_id;
1105 let called_user_id = UserId::from_proto(request.called_user_id);
1106 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1107 if !session
1108 .db()
1109 .await
1110 .has_contact(calling_user_id, called_user_id)
1111 .await?
1112 {
1113 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1114 }
1115
1116 let incoming_call = {
1117 let (room, incoming_call) = &mut *session
1118 .db()
1119 .await
1120 .call(
1121 room_id,
1122 calling_user_id,
1123 calling_connection_id,
1124 called_user_id,
1125 initial_project_id,
1126 )
1127 .await?;
1128 room_updated(&room, &session.peer);
1129 mem::take(incoming_call)
1130 };
1131 update_user_contacts(called_user_id, &session).await?;
1132
1133 let mut calls = session
1134 .connection_pool()
1135 .await
1136 .user_connection_ids(called_user_id)
1137 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1138 .collect::<FuturesUnordered<_>>();
1139
1140 while let Some(call_response) = calls.next().await {
1141 match call_response.as_ref() {
1142 Ok(_) => {
1143 response.send(proto::Ack {})?;
1144 return Ok(());
1145 }
1146 Err(_) => {
1147 call_response.trace_err();
1148 }
1149 }
1150 }
1151
1152 {
1153 let room = session
1154 .db()
1155 .await
1156 .call_failed(room_id, called_user_id)
1157 .await?;
1158 room_updated(&room, &session.peer);
1159 }
1160 update_user_contacts(called_user_id, &session).await?;
1161
1162 Err(anyhow!("failed to ring user"))?
1163}
1164
1165async fn cancel_call(
1166 request: proto::CancelCall,
1167 response: Response<proto::CancelCall>,
1168 session: Session,
1169) -> Result<()> {
1170 let called_user_id = UserId::from_proto(request.called_user_id);
1171 let room_id = RoomId::from_proto(request.room_id);
1172 {
1173 let room = session
1174 .db()
1175 .await
1176 .cancel_call(room_id, session.connection_id, called_user_id)
1177 .await?;
1178 room_updated(&room, &session.peer);
1179 }
1180
1181 for connection_id in session
1182 .connection_pool()
1183 .await
1184 .user_connection_ids(called_user_id)
1185 {
1186 session
1187 .peer
1188 .send(
1189 connection_id,
1190 proto::CallCanceled {
1191 room_id: room_id.to_proto(),
1192 },
1193 )
1194 .trace_err();
1195 }
1196 response.send(proto::Ack {})?;
1197
1198 update_user_contacts(called_user_id, &session).await?;
1199 Ok(())
1200}
1201
1202async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1203 let room_id = RoomId::from_proto(message.room_id);
1204 {
1205 let room = session
1206 .db()
1207 .await
1208 .decline_call(Some(room_id), session.user_id)
1209 .await?
1210 .ok_or_else(|| anyhow!("failed to decline call"))?;
1211 room_updated(&room, &session.peer);
1212 }
1213
1214 for connection_id in session
1215 .connection_pool()
1216 .await
1217 .user_connection_ids(session.user_id)
1218 {
1219 session
1220 .peer
1221 .send(
1222 connection_id,
1223 proto::CallCanceled {
1224 room_id: room_id.to_proto(),
1225 },
1226 )
1227 .trace_err();
1228 }
1229 update_user_contacts(session.user_id, &session).await?;
1230 Ok(())
1231}
1232
1233async fn update_participant_location(
1234 request: proto::UpdateParticipantLocation,
1235 response: Response<proto::UpdateParticipantLocation>,
1236 session: Session,
1237) -> Result<()> {
1238 let room_id = RoomId::from_proto(request.room_id);
1239 let location = request
1240 .location
1241 .ok_or_else(|| anyhow!("invalid location"))?;
1242 let room = session
1243 .db()
1244 .await
1245 .update_room_participant_location(room_id, session.connection_id, location)
1246 .await?;
1247 room_updated(&room, &session.peer);
1248 response.send(proto::Ack {})?;
1249 Ok(())
1250}
1251
1252async fn share_project(
1253 request: proto::ShareProject,
1254 response: Response<proto::ShareProject>,
1255 session: Session,
1256) -> Result<()> {
1257 let (project_id, room) = &*session
1258 .db()
1259 .await
1260 .share_project(
1261 RoomId::from_proto(request.room_id),
1262 session.connection_id,
1263 &request.worktrees,
1264 )
1265 .await?;
1266 response.send(proto::ShareProjectResponse {
1267 project_id: project_id.to_proto(),
1268 })?;
1269 room_updated(&room, &session.peer);
1270
1271 Ok(())
1272}
1273
1274async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1275 let project_id = ProjectId::from_proto(message.project_id);
1276
1277 let (room, guest_connection_ids) = &*session
1278 .db()
1279 .await
1280 .unshare_project(project_id, session.connection_id)
1281 .await?;
1282
1283 broadcast(
1284 Some(session.connection_id),
1285 guest_connection_ids.iter().copied(),
1286 |conn_id| session.peer.send(conn_id, message.clone()),
1287 );
1288 room_updated(&room, &session.peer);
1289
1290 Ok(())
1291}
1292
1293async fn join_project(
1294 request: proto::JoinProject,
1295 response: Response<proto::JoinProject>,
1296 session: Session,
1297) -> Result<()> {
1298 let project_id = ProjectId::from_proto(request.project_id);
1299 let guest_user_id = session.user_id;
1300
1301 tracing::info!(%project_id, "join project");
1302
1303 let (project, replica_id) = &mut *session
1304 .db()
1305 .await
1306 .join_project(project_id, session.connection_id)
1307 .await?;
1308
1309 let collaborators = project
1310 .collaborators
1311 .iter()
1312 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1313 .map(|collaborator| collaborator.to_proto())
1314 .collect::<Vec<_>>();
1315
1316 let worktrees = project
1317 .worktrees
1318 .iter()
1319 .map(|(id, worktree)| proto::WorktreeMetadata {
1320 id: *id,
1321 root_name: worktree.root_name.clone(),
1322 visible: worktree.visible,
1323 abs_path: worktree.abs_path.clone(),
1324 })
1325 .collect::<Vec<_>>();
1326
1327 for collaborator in &collaborators {
1328 session
1329 .peer
1330 .send(
1331 collaborator.peer_id.unwrap().into(),
1332 proto::AddProjectCollaborator {
1333 project_id: project_id.to_proto(),
1334 collaborator: Some(proto::Collaborator {
1335 peer_id: Some(session.connection_id.into()),
1336 replica_id: replica_id.0 as u32,
1337 user_id: guest_user_id.to_proto(),
1338 }),
1339 },
1340 )
1341 .trace_err();
1342 }
1343
1344 // First, we send the metadata associated with each worktree.
1345 response.send(proto::JoinProjectResponse {
1346 worktrees: worktrees.clone(),
1347 replica_id: replica_id.0 as u32,
1348 collaborators: collaborators.clone(),
1349 language_servers: project.language_servers.clone(),
1350 })?;
1351
1352 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1353 #[cfg(any(test, feature = "test-support"))]
1354 const MAX_CHUNK_SIZE: usize = 2;
1355 #[cfg(not(any(test, feature = "test-support")))]
1356 const MAX_CHUNK_SIZE: usize = 256;
1357
1358 // Stream this worktree's entries.
1359 let message = proto::UpdateWorktree {
1360 project_id: project_id.to_proto(),
1361 worktree_id,
1362 abs_path: worktree.abs_path.clone(),
1363 root_name: worktree.root_name,
1364 updated_entries: worktree.entries,
1365 removed_entries: Default::default(),
1366 scan_id: worktree.scan_id,
1367 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1368 };
1369 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1370 session.peer.send(session.connection_id, update.clone())?;
1371 }
1372
1373 // Stream this worktree's diagnostics.
1374 for summary in worktree.diagnostic_summaries {
1375 session.peer.send(
1376 session.connection_id,
1377 proto::UpdateDiagnosticSummary {
1378 project_id: project_id.to_proto(),
1379 worktree_id: worktree.id,
1380 summary: Some(summary),
1381 },
1382 )?;
1383 }
1384 }
1385
1386 for language_server in &project.language_servers {
1387 session.peer.send(
1388 session.connection_id,
1389 proto::UpdateLanguageServer {
1390 project_id: project_id.to_proto(),
1391 language_server_id: language_server.id,
1392 variant: Some(
1393 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1394 proto::LspDiskBasedDiagnosticsUpdated {},
1395 ),
1396 ),
1397 },
1398 )?;
1399 }
1400
1401 Ok(())
1402}
1403
1404async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1405 let sender_id = session.connection_id;
1406 let project_id = ProjectId::from_proto(request.project_id);
1407
1408 let project = session
1409 .db()
1410 .await
1411 .leave_project(project_id, sender_id)
1412 .await?;
1413 tracing::info!(
1414 %project_id,
1415 host_user_id = %project.host_user_id,
1416 host_connection_id = %project.host_connection_id,
1417 "leave project"
1418 );
1419 project_left(&project, &session);
1420
1421 Ok(())
1422}
1423
1424async fn update_project(
1425 request: proto::UpdateProject,
1426 response: Response<proto::UpdateProject>,
1427 session: Session,
1428) -> Result<()> {
1429 let project_id = ProjectId::from_proto(request.project_id);
1430 let (room, guest_connection_ids) = &*session
1431 .db()
1432 .await
1433 .update_project(project_id, session.connection_id, &request.worktrees)
1434 .await?;
1435 broadcast(
1436 Some(session.connection_id),
1437 guest_connection_ids.iter().copied(),
1438 |connection_id| {
1439 session
1440 .peer
1441 .forward_send(session.connection_id, connection_id, request.clone())
1442 },
1443 );
1444 room_updated(&room, &session.peer);
1445 response.send(proto::Ack {})?;
1446
1447 Ok(())
1448}
1449
1450async fn update_worktree(
1451 request: proto::UpdateWorktree,
1452 response: Response<proto::UpdateWorktree>,
1453 session: Session,
1454) -> Result<()> {
1455 let guest_connection_ids = session
1456 .db()
1457 .await
1458 .update_worktree(&request, session.connection_id)
1459 .await?;
1460
1461 broadcast(
1462 Some(session.connection_id),
1463 guest_connection_ids.iter().copied(),
1464 |connection_id| {
1465 session
1466 .peer
1467 .forward_send(session.connection_id, connection_id, request.clone())
1468 },
1469 );
1470 response.send(proto::Ack {})?;
1471 Ok(())
1472}
1473
1474async fn update_diagnostic_summary(
1475 message: proto::UpdateDiagnosticSummary,
1476 session: Session,
1477) -> Result<()> {
1478 let guest_connection_ids = session
1479 .db()
1480 .await
1481 .update_diagnostic_summary(&message, session.connection_id)
1482 .await?;
1483
1484 broadcast(
1485 Some(session.connection_id),
1486 guest_connection_ids.iter().copied(),
1487 |connection_id| {
1488 session
1489 .peer
1490 .forward_send(session.connection_id, connection_id, message.clone())
1491 },
1492 );
1493
1494 Ok(())
1495}
1496
1497async fn start_language_server(
1498 request: proto::StartLanguageServer,
1499 session: Session,
1500) -> Result<()> {
1501 let guest_connection_ids = session
1502 .db()
1503 .await
1504 .start_language_server(&request, session.connection_id)
1505 .await?;
1506
1507 broadcast(
1508 Some(session.connection_id),
1509 guest_connection_ids.iter().copied(),
1510 |connection_id| {
1511 session
1512 .peer
1513 .forward_send(session.connection_id, connection_id, request.clone())
1514 },
1515 );
1516 Ok(())
1517}
1518
1519async fn update_language_server(
1520 request: proto::UpdateLanguageServer,
1521 session: Session,
1522) -> Result<()> {
1523 session.executor.record_backtrace();
1524 let project_id = ProjectId::from_proto(request.project_id);
1525 let project_connection_ids = session
1526 .db()
1527 .await
1528 .project_connection_ids(project_id, session.connection_id)
1529 .await?;
1530 broadcast(
1531 Some(session.connection_id),
1532 project_connection_ids.iter().copied(),
1533 |connection_id| {
1534 session
1535 .peer
1536 .forward_send(session.connection_id, connection_id, request.clone())
1537 },
1538 );
1539 Ok(())
1540}
1541
1542async fn forward_project_request<T>(
1543 request: T,
1544 response: Response<T>,
1545 session: Session,
1546) -> Result<()>
1547where
1548 T: EntityMessage + RequestMessage,
1549{
1550 session.executor.record_backtrace();
1551 let project_id = ProjectId::from_proto(request.remote_entity_id());
1552 let host_connection_id = {
1553 let collaborators = session
1554 .db()
1555 .await
1556 .project_collaborators(project_id, session.connection_id)
1557 .await?;
1558 collaborators
1559 .iter()
1560 .find(|collaborator| collaborator.is_host)
1561 .ok_or_else(|| anyhow!("host not found"))?
1562 .connection_id
1563 };
1564
1565 let payload = session
1566 .peer
1567 .forward_request(session.connection_id, host_connection_id, request)
1568 .await?;
1569
1570 response.send(payload)?;
1571 Ok(())
1572}
1573
1574async fn save_buffer(
1575 request: proto::SaveBuffer,
1576 response: Response<proto::SaveBuffer>,
1577 session: Session,
1578) -> Result<()> {
1579 let project_id = ProjectId::from_proto(request.project_id);
1580 let host_connection_id = {
1581 let collaborators = session
1582 .db()
1583 .await
1584 .project_collaborators(project_id, session.connection_id)
1585 .await?;
1586 collaborators
1587 .iter()
1588 .find(|collaborator| collaborator.is_host)
1589 .ok_or_else(|| anyhow!("host not found"))?
1590 .connection_id
1591 };
1592 let response_payload = session
1593 .peer
1594 .forward_request(session.connection_id, host_connection_id, request.clone())
1595 .await?;
1596
1597 let mut collaborators = session
1598 .db()
1599 .await
1600 .project_collaborators(project_id, session.connection_id)
1601 .await?;
1602 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1603 let project_connection_ids = collaborators
1604 .iter()
1605 .map(|collaborator| collaborator.connection_id);
1606 broadcast(
1607 Some(host_connection_id),
1608 project_connection_ids,
1609 |conn_id| {
1610 session
1611 .peer
1612 .forward_send(host_connection_id, conn_id, response_payload.clone())
1613 },
1614 );
1615 response.send(response_payload)?;
1616 Ok(())
1617}
1618
1619async fn create_buffer_for_peer(
1620 request: proto::CreateBufferForPeer,
1621 session: Session,
1622) -> Result<()> {
1623 session.executor.record_backtrace();
1624 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1625 session
1626 .peer
1627 .forward_send(session.connection_id, peer_id.into(), request)?;
1628 Ok(())
1629}
1630
1631async fn update_buffer(
1632 request: proto::UpdateBuffer,
1633 response: Response<proto::UpdateBuffer>,
1634 session: Session,
1635) -> Result<()> {
1636 session.executor.record_backtrace();
1637 let project_id = ProjectId::from_proto(request.project_id);
1638 let project_connection_ids = session
1639 .db()
1640 .await
1641 .project_connection_ids(project_id, session.connection_id)
1642 .await?;
1643
1644 session.executor.record_backtrace();
1645
1646 broadcast(
1647 Some(session.connection_id),
1648 project_connection_ids.iter().copied(),
1649 |connection_id| {
1650 session
1651 .peer
1652 .forward_send(session.connection_id, connection_id, request.clone())
1653 },
1654 );
1655 response.send(proto::Ack {})?;
1656 Ok(())
1657}
1658
1659async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1660 let project_id = ProjectId::from_proto(request.project_id);
1661 let project_connection_ids = session
1662 .db()
1663 .await
1664 .project_connection_ids(project_id, session.connection_id)
1665 .await?;
1666
1667 broadcast(
1668 Some(session.connection_id),
1669 project_connection_ids.iter().copied(),
1670 |connection_id| {
1671 session
1672 .peer
1673 .forward_send(session.connection_id, connection_id, request.clone())
1674 },
1675 );
1676 Ok(())
1677}
1678
1679async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1680 let project_id = ProjectId::from_proto(request.project_id);
1681 let project_connection_ids = session
1682 .db()
1683 .await
1684 .project_connection_ids(project_id, session.connection_id)
1685 .await?;
1686 broadcast(
1687 Some(session.connection_id),
1688 project_connection_ids.iter().copied(),
1689 |connection_id| {
1690 session
1691 .peer
1692 .forward_send(session.connection_id, connection_id, request.clone())
1693 },
1694 );
1695 Ok(())
1696}
1697
1698async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1699 let project_id = ProjectId::from_proto(request.project_id);
1700 let project_connection_ids = session
1701 .db()
1702 .await
1703 .project_connection_ids(project_id, session.connection_id)
1704 .await?;
1705 broadcast(
1706 Some(session.connection_id),
1707 project_connection_ids.iter().copied(),
1708 |connection_id| {
1709 session
1710 .peer
1711 .forward_send(session.connection_id, connection_id, request.clone())
1712 },
1713 );
1714 Ok(())
1715}
1716
1717async fn follow(
1718 request: proto::Follow,
1719 response: Response<proto::Follow>,
1720 session: Session,
1721) -> Result<()> {
1722 let project_id = ProjectId::from_proto(request.project_id);
1723 let leader_id = request
1724 .leader_id
1725 .ok_or_else(|| anyhow!("invalid leader id"))?
1726 .into();
1727 let follower_id = session.connection_id;
1728
1729 {
1730 let project_connection_ids = session
1731 .db()
1732 .await
1733 .project_connection_ids(project_id, session.connection_id)
1734 .await?;
1735
1736 if !project_connection_ids.contains(&leader_id) {
1737 Err(anyhow!("no such peer"))?;
1738 }
1739 }
1740
1741 let mut response_payload = session
1742 .peer
1743 .forward_request(session.connection_id, leader_id, request)
1744 .await?;
1745 response_payload
1746 .views
1747 .retain(|view| view.leader_id != Some(follower_id.into()));
1748 response.send(response_payload)?;
1749
1750 let room = session
1751 .db()
1752 .await
1753 .follow(project_id, leader_id, follower_id)
1754 .await?;
1755 room_updated(&room, &session.peer);
1756
1757 Ok(())
1758}
1759
1760async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1761 let project_id = ProjectId::from_proto(request.project_id);
1762 let leader_id = request
1763 .leader_id
1764 .ok_or_else(|| anyhow!("invalid leader id"))?
1765 .into();
1766 let follower_id = session.connection_id;
1767
1768 if !session
1769 .db()
1770 .await
1771 .project_connection_ids(project_id, session.connection_id)
1772 .await?
1773 .contains(&leader_id)
1774 {
1775 Err(anyhow!("no such peer"))?;
1776 }
1777
1778 session
1779 .peer
1780 .forward_send(session.connection_id, leader_id, request)?;
1781
1782 let room = session
1783 .db()
1784 .await
1785 .unfollow(project_id, leader_id, follower_id)
1786 .await?;
1787 room_updated(&room, &session.peer);
1788
1789 Ok(())
1790}
1791
1792async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1793 let project_id = ProjectId::from_proto(request.project_id);
1794 let project_connection_ids = session
1795 .db
1796 .lock()
1797 .await
1798 .project_connection_ids(project_id, session.connection_id)
1799 .await?;
1800
1801 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1802 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1803 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1804 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1805 });
1806 for follower_peer_id in request.follower_ids.iter().copied() {
1807 let follower_connection_id = follower_peer_id.into();
1808 if project_connection_ids.contains(&follower_connection_id)
1809 && Some(follower_peer_id) != leader_id
1810 {
1811 session.peer.forward_send(
1812 session.connection_id,
1813 follower_connection_id,
1814 request.clone(),
1815 )?;
1816 }
1817 }
1818 Ok(())
1819}
1820
1821async fn get_users(
1822 request: proto::GetUsers,
1823 response: Response<proto::GetUsers>,
1824 session: Session,
1825) -> Result<()> {
1826 let user_ids = request
1827 .user_ids
1828 .into_iter()
1829 .map(UserId::from_proto)
1830 .collect();
1831 let users = session
1832 .db()
1833 .await
1834 .get_users_by_ids(user_ids)
1835 .await?
1836 .into_iter()
1837 .map(|user| proto::User {
1838 id: user.id.to_proto(),
1839 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1840 github_login: user.github_login,
1841 })
1842 .collect();
1843 response.send(proto::UsersResponse { users })?;
1844 Ok(())
1845}
1846
1847async fn fuzzy_search_users(
1848 request: proto::FuzzySearchUsers,
1849 response: Response<proto::FuzzySearchUsers>,
1850 session: Session,
1851) -> Result<()> {
1852 let query = request.query;
1853 let users = match query.len() {
1854 0 => vec![],
1855 1 | 2 => session
1856 .db()
1857 .await
1858 .get_user_by_github_account(&query, None)
1859 .await?
1860 .into_iter()
1861 .collect(),
1862 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1863 };
1864 let users = users
1865 .into_iter()
1866 .filter(|user| user.id != session.user_id)
1867 .map(|user| proto::User {
1868 id: user.id.to_proto(),
1869 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1870 github_login: user.github_login,
1871 })
1872 .collect();
1873 response.send(proto::UsersResponse { users })?;
1874 Ok(())
1875}
1876
1877async fn request_contact(
1878 request: proto::RequestContact,
1879 response: Response<proto::RequestContact>,
1880 session: Session,
1881) -> Result<()> {
1882 let requester_id = session.user_id;
1883 let responder_id = UserId::from_proto(request.responder_id);
1884 if requester_id == responder_id {
1885 return Err(anyhow!("cannot add yourself as a contact"))?;
1886 }
1887
1888 session
1889 .db()
1890 .await
1891 .send_contact_request(requester_id, responder_id)
1892 .await?;
1893
1894 // Update outgoing contact requests of requester
1895 let mut update = proto::UpdateContacts::default();
1896 update.outgoing_requests.push(responder_id.to_proto());
1897 for connection_id in session
1898 .connection_pool()
1899 .await
1900 .user_connection_ids(requester_id)
1901 {
1902 session.peer.send(connection_id, update.clone())?;
1903 }
1904
1905 // Update incoming contact requests of responder
1906 let mut update = proto::UpdateContacts::default();
1907 update
1908 .incoming_requests
1909 .push(proto::IncomingContactRequest {
1910 requester_id: requester_id.to_proto(),
1911 should_notify: true,
1912 });
1913 for connection_id in session
1914 .connection_pool()
1915 .await
1916 .user_connection_ids(responder_id)
1917 {
1918 session.peer.send(connection_id, update.clone())?;
1919 }
1920
1921 response.send(proto::Ack {})?;
1922 Ok(())
1923}
1924
1925async fn respond_to_contact_request(
1926 request: proto::RespondToContactRequest,
1927 response: Response<proto::RespondToContactRequest>,
1928 session: Session,
1929) -> Result<()> {
1930 let responder_id = session.user_id;
1931 let requester_id = UserId::from_proto(request.requester_id);
1932 let db = session.db().await;
1933 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1934 db.dismiss_contact_notification(responder_id, requester_id)
1935 .await?;
1936 } else {
1937 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1938
1939 db.respond_to_contact_request(responder_id, requester_id, accept)
1940 .await?;
1941 let requester_busy = db.is_user_busy(requester_id).await?;
1942 let responder_busy = db.is_user_busy(responder_id).await?;
1943
1944 let pool = session.connection_pool().await;
1945 // Update responder with new contact
1946 let mut update = proto::UpdateContacts::default();
1947 if accept {
1948 update
1949 .contacts
1950 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1951 }
1952 update
1953 .remove_incoming_requests
1954 .push(requester_id.to_proto());
1955 for connection_id in pool.user_connection_ids(responder_id) {
1956 session.peer.send(connection_id, update.clone())?;
1957 }
1958
1959 // Update requester with new contact
1960 let mut update = proto::UpdateContacts::default();
1961 if accept {
1962 update
1963 .contacts
1964 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1965 }
1966 update
1967 .remove_outgoing_requests
1968 .push(responder_id.to_proto());
1969 for connection_id in pool.user_connection_ids(requester_id) {
1970 session.peer.send(connection_id, update.clone())?;
1971 }
1972 }
1973
1974 response.send(proto::Ack {})?;
1975 Ok(())
1976}
1977
1978async fn remove_contact(
1979 request: proto::RemoveContact,
1980 response: Response<proto::RemoveContact>,
1981 session: Session,
1982) -> Result<()> {
1983 let requester_id = session.user_id;
1984 let responder_id = UserId::from_proto(request.user_id);
1985 let db = session.db().await;
1986 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
1987
1988 let pool = session.connection_pool().await;
1989 // Update outgoing contact requests of requester
1990 let mut update = proto::UpdateContacts::default();
1991 if contact_accepted {
1992 update.remove_contacts.push(responder_id.to_proto());
1993 } else {
1994 update
1995 .remove_outgoing_requests
1996 .push(responder_id.to_proto());
1997 }
1998 for connection_id in pool.user_connection_ids(requester_id) {
1999 session.peer.send(connection_id, update.clone())?;
2000 }
2001
2002 // Update incoming contact requests of responder
2003 let mut update = proto::UpdateContacts::default();
2004 if contact_accepted {
2005 update.remove_contacts.push(requester_id.to_proto());
2006 } else {
2007 update
2008 .remove_incoming_requests
2009 .push(requester_id.to_proto());
2010 }
2011 for connection_id in pool.user_connection_ids(responder_id) {
2012 session.peer.send(connection_id, update.clone())?;
2013 }
2014
2015 response.send(proto::Ack {})?;
2016 Ok(())
2017}
2018
2019async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2020 let project_id = ProjectId::from_proto(request.project_id);
2021 let project_connection_ids = session
2022 .db()
2023 .await
2024 .project_connection_ids(project_id, session.connection_id)
2025 .await?;
2026 broadcast(
2027 Some(session.connection_id),
2028 project_connection_ids.iter().copied(),
2029 |connection_id| {
2030 session
2031 .peer
2032 .forward_send(session.connection_id, connection_id, request.clone())
2033 },
2034 );
2035 Ok(())
2036}
2037
2038async fn get_private_user_info(
2039 _request: proto::GetPrivateUserInfo,
2040 response: Response<proto::GetPrivateUserInfo>,
2041 session: Session,
2042) -> Result<()> {
2043 let metrics_id = session
2044 .db()
2045 .await
2046 .get_user_metrics_id(session.user_id)
2047 .await?;
2048 let user = session
2049 .db()
2050 .await
2051 .get_user_by_id(session.user_id)
2052 .await?
2053 .ok_or_else(|| anyhow!("user not found"))?;
2054 response.send(proto::GetPrivateUserInfoResponse {
2055 metrics_id,
2056 staff: user.admin,
2057 })?;
2058 Ok(())
2059}
2060
2061fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2062 match message {
2063 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2064 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2065 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2066 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2067 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2068 code: frame.code.into(),
2069 reason: frame.reason,
2070 })),
2071 }
2072}
2073
2074fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2075 match message {
2076 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2077 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2078 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2079 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2080 AxumMessage::Close(frame) => {
2081 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2082 code: frame.code.into(),
2083 reason: frame.reason,
2084 }))
2085 }
2086 }
2087}
2088
2089fn build_initial_contacts_update(
2090 contacts: Vec<db::Contact>,
2091 pool: &ConnectionPool,
2092) -> proto::UpdateContacts {
2093 let mut update = proto::UpdateContacts::default();
2094
2095 for contact in contacts {
2096 match contact {
2097 db::Contact::Accepted {
2098 user_id,
2099 should_notify,
2100 busy,
2101 } => {
2102 update
2103 .contacts
2104 .push(contact_for_user(user_id, should_notify, busy, &pool));
2105 }
2106 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2107 db::Contact::Incoming {
2108 user_id,
2109 should_notify,
2110 } => update
2111 .incoming_requests
2112 .push(proto::IncomingContactRequest {
2113 requester_id: user_id.to_proto(),
2114 should_notify,
2115 }),
2116 }
2117 }
2118
2119 update
2120}
2121
2122fn contact_for_user(
2123 user_id: UserId,
2124 should_notify: bool,
2125 busy: bool,
2126 pool: &ConnectionPool,
2127) -> proto::Contact {
2128 proto::Contact {
2129 user_id: user_id.to_proto(),
2130 online: pool.is_user_online(user_id),
2131 busy,
2132 should_notify,
2133 }
2134}
2135
2136fn room_updated(room: &proto::Room, peer: &Peer) {
2137 broadcast(
2138 None,
2139 room.participants
2140 .iter()
2141 .filter_map(|participant| Some(participant.peer_id?.into())),
2142 |peer_id| {
2143 peer.send(
2144 peer_id.into(),
2145 proto::RoomUpdated {
2146 room: Some(room.clone()),
2147 },
2148 )
2149 },
2150 );
2151}
2152
2153async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2154 let db = session.db().await;
2155 let contacts = db.get_contacts(user_id).await?;
2156 let busy = db.is_user_busy(user_id).await?;
2157
2158 let pool = session.connection_pool().await;
2159 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2160 for contact in contacts {
2161 if let db::Contact::Accepted {
2162 user_id: contact_user_id,
2163 ..
2164 } = contact
2165 {
2166 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2167 session
2168 .peer
2169 .send(
2170 contact_conn_id,
2171 proto::UpdateContacts {
2172 contacts: vec![updated_contact.clone()],
2173 remove_contacts: Default::default(),
2174 incoming_requests: Default::default(),
2175 remove_incoming_requests: Default::default(),
2176 outgoing_requests: Default::default(),
2177 remove_outgoing_requests: Default::default(),
2178 },
2179 )
2180 .trace_err();
2181 }
2182 }
2183 }
2184 Ok(())
2185}
2186
2187async fn leave_room_for_session(session: &Session) -> Result<()> {
2188 let mut contacts_to_update = HashSet::default();
2189
2190 let room_id;
2191 let canceled_calls_to_user_ids;
2192 let live_kit_room;
2193 let delete_live_kit_room;
2194 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2195 contacts_to_update.insert(session.user_id);
2196
2197 for project in left_room.left_projects.values() {
2198 project_left(project, session);
2199 }
2200
2201 room_updated(&left_room.room, &session.peer);
2202 room_id = RoomId::from_proto(left_room.room.id);
2203 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2204 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2205 delete_live_kit_room = left_room.room.participants.is_empty();
2206 } else {
2207 return Ok(());
2208 }
2209
2210 {
2211 let pool = session.connection_pool().await;
2212 for canceled_user_id in canceled_calls_to_user_ids {
2213 for connection_id in pool.user_connection_ids(canceled_user_id) {
2214 session
2215 .peer
2216 .send(
2217 connection_id,
2218 proto::CallCanceled {
2219 room_id: room_id.to_proto(),
2220 },
2221 )
2222 .trace_err();
2223 }
2224 contacts_to_update.insert(canceled_user_id);
2225 }
2226 }
2227
2228 for contact_user_id in contacts_to_update {
2229 update_user_contacts(contact_user_id, &session).await?;
2230 }
2231
2232 if let Some(live_kit) = session.live_kit_client.as_ref() {
2233 live_kit
2234 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2235 .await
2236 .trace_err();
2237
2238 if delete_live_kit_room {
2239 live_kit.delete_room(live_kit_room).await.trace_err();
2240 }
2241 }
2242
2243 Ok(())
2244}
2245
2246fn project_left(project: &db::LeftProject, session: &Session) {
2247 for connection_id in &project.connection_ids {
2248 if project.host_user_id == session.user_id {
2249 session
2250 .peer
2251 .send(
2252 *connection_id,
2253 proto::UnshareProject {
2254 project_id: project.id.to_proto(),
2255 },
2256 )
2257 .trace_err();
2258 } else {
2259 session
2260 .peer
2261 .send(
2262 *connection_id,
2263 proto::RemoveProjectCollaborator {
2264 project_id: project.id.to_proto(),
2265 peer_id: Some(session.connection_id.into()),
2266 },
2267 )
2268 .trace_err();
2269 }
2270 }
2271}
2272
/// Extension for fallible values whose failures should be reported and then
/// discarded rather than propagated.
pub trait ResultExt {
    /// The success value type.
    type Ok;

    /// Consumes `self`, yielding the success value if there is one and `None`
    /// otherwise. Implementations are expected to report the failure (for
    /// example by logging it) before discarding it.
    fn trace_err(self) -> Option<<Self as ResultExt>::Ok>;
}
2278
2279impl<T, E> ResultExt for Result<T, E>
2280where
2281 E: std::fmt::Debug,
2282{
2283 type Ok = T;
2284
2285 fn trace_err(self) -> Option<T> {
2286 match self {
2287 Ok(value) => Some(value),
2288 Err(error) => {
2289 tracing::error!("{:?}", error);
2290 None
2291 }
2292 }
2293 }
2294}