1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::{watch, Semaphore};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
60pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
61pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98 executor: Executor,
99}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 id: parking_lot::Mutex<ServerId>,
143 peer: Arc<Peer>,
144 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 app_state: Arc<AppState>,
146 executor: Executor,
147 handlers: HashMap<TypeId, MessageHandler>,
148 teardown: watch::Sender<()>,
149}
150
151pub(crate) struct ConnectionPoolGuard<'a> {
152 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
153 _not_send: PhantomData<Rc<()>>,
154}
155
156#[derive(Serialize)]
157pub struct ServerSnapshot<'a> {
158 peer: &'a Peer,
159 #[serde(serialize_with = "serialize_deref")]
160 connection_pool: ConnectionPoolGuard<'a>,
161}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(save_buffer)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 pub async fn start(&self) -> Result<()> {
247 let server_id = *self.id.lock();
248 let app_state = self.app_state.clone();
249 let peer = self.peer.clone();
250 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
251 let pool = self.connection_pool.clone();
252 let live_kit_client = self.app_state.live_kit_client.clone();
253
254 let span = info_span!("start server");
255 self.executor.spawn_detached(
256 async move {
257 tracing::info!("waiting for cleanup timeout");
258 timeout.await;
259 tracing::info!("cleanup timeout expired, retrieving stale rooms");
260 if let Some(room_ids) = app_state
261 .db
262 .stale_room_ids(&app_state.config.zed_environment, server_id)
263 .await
264 .trace_err()
265 {
266 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
267 for room_id in room_ids {
268 let mut contacts_to_update = HashSet::default();
269 let mut canceled_calls_to_user_ids = Vec::new();
270 let mut live_kit_room = String::new();
271 let mut delete_live_kit_room = false;
272
273 if let Some(mut refreshed_room) = app_state
274 .db
275 .refresh_room(room_id, server_id)
276 .await
277 .trace_err()
278 {
279 tracing::info!(
280 room_id = room_id.0,
281 new_participant_count = refreshed_room.room.participants.len(),
282 "refreshed room"
283 );
284 room_updated(&refreshed_room.room, &peer);
285 contacts_to_update
286 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
287 contacts_to_update
288 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
289 canceled_calls_to_user_ids =
290 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
291 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
292 delete_live_kit_room = refreshed_room.room.participants.is_empty();
293 }
294
295 {
296 let pool = pool.lock();
297 for canceled_user_id in canceled_calls_to_user_ids {
298 for connection_id in pool.user_connection_ids(canceled_user_id) {
299 peer.send(
300 connection_id,
301 proto::CallCanceled {
302 room_id: room_id.to_proto(),
303 },
304 )
305 .trace_err();
306 }
307 }
308 }
309
310 for user_id in contacts_to_update {
311 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
312 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
313 if let Some((busy, contacts)) = busy.zip(contacts) {
314 let pool = pool.lock();
315 let updated_contact = contact_for_user(user_id, false, busy, &pool);
316 for contact in contacts {
317 if let db::Contact::Accepted {
318 user_id: contact_user_id,
319 ..
320 } = contact
321 {
322 for contact_conn_id in
323 pool.user_connection_ids(contact_user_id)
324 {
325 peer.send(
326 contact_conn_id,
327 proto::UpdateContacts {
328 contacts: vec![updated_contact.clone()],
329 remove_contacts: Default::default(),
330 incoming_requests: Default::default(),
331 remove_incoming_requests: Default::default(),
332 outgoing_requests: Default::default(),
333 remove_outgoing_requests: Default::default(),
334 },
335 )
336 .trace_err();
337 }
338 }
339 }
340 }
341 }
342
343 if let Some(live_kit) = live_kit_client.as_ref() {
344 if delete_live_kit_room {
345 live_kit.delete_room(live_kit_room).await.trace_err();
346 }
347 }
348 }
349 }
350
351 app_state
352 .db
353 .delete_stale_servers(&app_state.config.zed_environment, server_id)
354 .await
355 .trace_err();
356 }
357 .instrument(span),
358 );
359 Ok(())
360 }
361
362 pub fn teardown(&self) {
363 self.peer.teardown();
364 self.connection_pool.lock().reset();
365 let _ = self.teardown.send(());
366 }
367
368 #[cfg(test)]
369 pub fn reset(&self, id: ServerId) {
370 self.teardown();
371 *self.id.lock() = id;
372 self.peer.reset(id.0 as u32);
373 }
374
375 #[cfg(test)]
376 pub fn id(&self) -> ServerId {
377 *self.id.lock()
378 }
379
380 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
381 where
382 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
383 Fut: 'static + Send + Future<Output = Result<()>>,
384 M: EnvelopedMessage,
385 {
386 let prev_handler = self.handlers.insert(
387 TypeId::of::<M>(),
388 Box::new(move |envelope, session| {
389 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
390 let span = info_span!(
391 "handle message",
392 payload_type = envelope.payload_type_name()
393 );
394 span.in_scope(|| {
395 tracing::info!(
396 payload_type = envelope.payload_type_name(),
397 "message received"
398 );
399 });
400 let future = (handler)(*envelope, session);
401 async move {
402 if let Err(error) = future.await {
403 tracing::error!(%error, "error handling message");
404 }
405 }
406 .instrument(span)
407 .boxed()
408 }),
409 );
410 if prev_handler.is_some() {
411 panic!("registered a handler for the same message twice");
412 }
413 self
414 }
415
416 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
417 where
418 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
419 Fut: 'static + Send + Future<Output = Result<()>>,
420 M: EnvelopedMessage,
421 {
422 self.add_handler(move |envelope, session| handler(envelope.payload, session));
423 self
424 }
425
426 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
427 where
428 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
429 Fut: Send + Future<Output = Result<()>>,
430 M: RequestMessage,
431 {
432 let handler = Arc::new(handler);
433 self.add_handler(move |envelope, session| {
434 let receipt = envelope.receipt();
435 let handler = handler.clone();
436 async move {
437 let peer = session.peer.clone();
438 let responded = Arc::new(AtomicBool::default());
439 let response = Response {
440 peer: peer.clone(),
441 responded: responded.clone(),
442 receipt,
443 };
444 match (handler)(envelope.payload, response, session).await {
445 Ok(()) => {
446 if responded.load(std::sync::atomic::Ordering::SeqCst) {
447 Ok(())
448 } else {
449 Err(anyhow!("handler did not send a response"))?
450 }
451 }
452 Err(error) => {
453 peer.respond_with_error(
454 receipt,
455 proto::Error {
456 message: error.to_string(),
457 },
458 )?;
459 Err(error)
460 }
461 }
462 }
463 })
464 }
465
466 pub fn handle_connection(
467 self: &Arc<Self>,
468 connection: Connection,
469 address: String,
470 user: User,
471 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
472 executor: Executor,
473 ) -> impl Future<Output = Result<()>> {
474 let this = self.clone();
475 let user_id = user.id;
476 let login = user.github_login;
477 let span = info_span!("handle connection", %user_id, %login, %address);
478 let mut teardown = self.teardown.subscribe();
479 async move {
480 let (connection_id, handle_io, mut incoming_rx) = this
481 .peer
482 .add_connection(connection, {
483 let executor = executor.clone();
484 move |duration| executor.sleep(duration)
485 });
486
487 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
488 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
489 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
490
491 if let Some(send_connection_id) = send_connection_id.take() {
492 let _ = send_connection_id.send(connection_id);
493 }
494
495 if !user.connected_once {
496 this.peer.send(connection_id, proto::ShowContacts {})?;
497 this.app_state.db.set_user_connected_once(user_id, true).await?;
498 }
499
500 let (contacts, invite_code) = future::try_join(
501 this.app_state.db.get_contacts(user_id),
502 this.app_state.db.get_invite_code_for_user(user_id)
503 ).await?;
504
505 {
506 let mut pool = this.connection_pool.lock();
507 pool.add_connection(connection_id, user_id, user.admin);
508 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
509
510 if let Some((code, count)) = invite_code {
511 this.peer.send(connection_id, proto::UpdateInviteInfo {
512 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
513 count: count as u32,
514 })?;
515 }
516 }
517
518 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
519 this.peer.send(connection_id, incoming_call)?;
520 }
521
522 let session = Session {
523 user_id,
524 connection_id,
525 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
526 peer: this.peer.clone(),
527 connection_pool: this.connection_pool.clone(),
528 live_kit_client: this.app_state.live_kit_client.clone(),
529 executor: executor.clone(),
530 };
531 update_user_contacts(user_id, &session).await?;
532
533 let handle_io = handle_io.fuse();
534 futures::pin_mut!(handle_io);
535
536 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
537 // This prevents deadlocks when e.g., client A performs a request to client B and
538 // client B performs a request to client A. If both clients stop processing further
539 // messages until their respective request completes, they won't have a chance to
540 // respond to the other client's request and cause a deadlock.
541 //
542 // This arrangement ensures we will attempt to process earlier messages first, but fall
543 // back to processing messages arrived later in the spirit of making progress.
544 let mut foreground_message_handlers = FuturesUnordered::new();
545 let concurrent_handlers = Arc::new(Semaphore::new(256));
546 loop {
547 let next_message = async {
548 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
549 let message = incoming_rx.next().await;
550 (permit, message)
551 }.fuse();
552 futures::pin_mut!(next_message);
553 futures::select_biased! {
554 _ = teardown.changed().fuse() => return Ok(()),
555 result = handle_io => {
556 if let Err(error) = result {
557 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
558 }
559 break;
560 }
561 _ = foreground_message_handlers.next() => {}
562 next_message = next_message => {
563 let (permit, message) = next_message;
564 if let Some(message) = message {
565 let type_name = message.payload_type_name();
566 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
567 let span_enter = span.enter();
568 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
569 let is_background = message.is_background();
570 let handle_message = (handler)(message, session.clone());
571 drop(span_enter);
572
573 let handle_message = async move {
574 handle_message.await;
575 drop(permit);
576 }.instrument(span);
577 if is_background {
578 executor.spawn_detached(handle_message);
579 } else {
580 foreground_message_handlers.push(handle_message);
581 }
582 } else {
583 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
584 }
585 } else {
586 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
587 break;
588 }
589 }
590 }
591 }
592
593 drop(foreground_message_handlers);
594 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
595 if let Err(error) = connection_lost(session, teardown, executor).await {
596 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
597 }
598
599 Ok(())
600 }.instrument(span)
601 }
602
603 pub async fn invite_code_redeemed(
604 self: &Arc<Self>,
605 inviter_id: UserId,
606 invitee_id: UserId,
607 ) -> Result<()> {
608 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
609 if let Some(code) = &user.invite_code {
610 let pool = self.connection_pool.lock();
611 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
612 for connection_id in pool.user_connection_ids(inviter_id) {
613 self.peer.send(
614 connection_id,
615 proto::UpdateContacts {
616 contacts: vec![invitee_contact.clone()],
617 ..Default::default()
618 },
619 )?;
620 self.peer.send(
621 connection_id,
622 proto::UpdateInviteInfo {
623 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
624 count: user.invite_count as u32,
625 },
626 )?;
627 }
628 }
629 }
630 Ok(())
631 }
632
633 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
634 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
635 if let Some(invite_code) = &user.invite_code {
636 let pool = self.connection_pool.lock();
637 for connection_id in pool.user_connection_ids(user_id) {
638 self.peer.send(
639 connection_id,
640 proto::UpdateInviteInfo {
641 url: format!(
642 "{}{}",
643 self.app_state.config.invite_link_prefix, invite_code
644 ),
645 count: user.invite_count as u32,
646 },
647 )?;
648 }
649 }
650 }
651 Ok(())
652 }
653
654 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
655 ServerSnapshot {
656 connection_pool: ConnectionPoolGuard {
657 guard: self.connection_pool.lock(),
658 _not_send: PhantomData,
659 },
660 peer: &self.peer,
661 }
662 }
663}
664
665impl<'a> Deref for ConnectionPoolGuard<'a> {
666 type Target = ConnectionPool;
667
668 fn deref(&self) -> &Self::Target {
669 &*self.guard
670 }
671}
672
673impl<'a> DerefMut for ConnectionPoolGuard<'a> {
674 fn deref_mut(&mut self) -> &mut Self::Target {
675 &mut *self.guard
676 }
677}
678
679impl<'a> Drop for ConnectionPoolGuard<'a> {
680 fn drop(&mut self) {
681 #[cfg(test)]
682 self.check_invariants();
683 }
684}
685
686fn broadcast<F>(
687 sender_id: Option<ConnectionId>,
688 receiver_ids: impl IntoIterator<Item = ConnectionId>,
689 mut f: F,
690) where
691 F: FnMut(ConnectionId) -> anyhow::Result<()>,
692{
693 for receiver_id in receiver_ids {
694 if Some(receiver_id) != sender_id {
695 if let Err(error) = f(receiver_id) {
696 tracing::error!("failed to send to {:?} {}", receiver_id, error);
697 }
698 }
699 }
700}
701
702lazy_static! {
703 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
704}
705
706pub struct ProtocolVersion(u32);
707
708impl Header for ProtocolVersion {
709 fn name() -> &'static HeaderName {
710 &ZED_PROTOCOL_VERSION
711 }
712
713 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
714 where
715 Self: Sized,
716 I: Iterator<Item = &'i axum::http::HeaderValue>,
717 {
718 let version = values
719 .next()
720 .ok_or_else(axum::headers::Error::invalid)?
721 .to_str()
722 .map_err(|_| axum::headers::Error::invalid())?
723 .parse()
724 .map_err(|_| axum::headers::Error::invalid())?;
725 Ok(Self(version))
726 }
727
728 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
729 values.extend([self.0.to_string().parse().unwrap()]);
730 }
731}
732
733pub fn routes(server: Arc<Server>) -> Router<Body> {
734 Router::new()
735 .route("/rpc", get(handle_websocket_request))
736 .layer(
737 ServiceBuilder::new()
738 .layer(Extension(server.app_state.clone()))
739 .layer(middleware::from_fn(auth::validate_header)),
740 )
741 .route("/metrics", get(handle_metrics))
742 .layer(Extension(server))
743}
744
745pub async fn handle_websocket_request(
746 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
747 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
748 Extension(server): Extension<Arc<Server>>,
749 Extension(user): Extension<User>,
750 ws: WebSocketUpgrade,
751) -> axum::response::Response {
752 if protocol_version != rpc::PROTOCOL_VERSION {
753 return (
754 StatusCode::UPGRADE_REQUIRED,
755 "client must be upgraded".to_string(),
756 )
757 .into_response();
758 }
759 let socket_address = socket_address.to_string();
760 ws.on_upgrade(move |socket| {
761 use util::ResultExt;
762 let socket = socket
763 .map_ok(to_tungstenite_message)
764 .err_into()
765 .with(|message| async move { Ok(to_axum_message(message)) });
766 let connection = Connection::new(Box::pin(socket));
767 async move {
768 server
769 .handle_connection(connection, socket_address, user, None, Executor::Production)
770 .await
771 .log_err();
772 }
773 })
774}
775
776pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
777 let connections = server
778 .connection_pool
779 .lock()
780 .connections()
781 .filter(|connection| !connection.admin)
782 .count();
783
784 METRIC_CONNECTIONS.set(connections as _);
785
786 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
787 METRIC_SHARED_PROJECTS.set(shared_projects as _);
788
789 let encoder = prometheus::TextEncoder::new();
790 let metric_families = prometheus::gather();
791 let encoded_metrics = encoder
792 .encode_to_string(&metric_families)
793 .map_err(|err| anyhow!("{}", err))?;
794 Ok(encoded_metrics)
795}
796
797#[instrument(err, skip(executor))]
798async fn connection_lost(
799 session: Session,
800 mut teardown: watch::Receiver<()>,
801 executor: Executor,
802) -> Result<()> {
803 session.peer.disconnect(session.connection_id);
804 session
805 .connection_pool()
806 .await
807 .remove_connection(session.connection_id)?;
808
809 session
810 .db()
811 .await
812 .connection_lost(session.connection_id)
813 .await
814 .trace_err();
815
816 futures::select_biased! {
817 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
818 leave_room_for_session(&session).await.trace_err();
819
820 if !session
821 .connection_pool()
822 .await
823 .is_user_online(session.user_id)
824 {
825 let db = session.db().await;
826 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
827 room_updated(&room, &session.peer);
828 }
829 }
830 update_user_contacts(session.user_id, &session).await?;
831 }
832 _ = teardown.changed().fuse() => {}
833 }
834
835 Ok(())
836}
837
838async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
839 response.send(proto::Ack {})?;
840 Ok(())
841}
842
843async fn create_room(
844 _request: proto::CreateRoom,
845 response: Response<proto::CreateRoom>,
846 session: Session,
847) -> Result<()> {
848 let live_kit_room = nanoid::nanoid!(30);
849 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
850 if let Some(_) = live_kit
851 .create_room(live_kit_room.clone())
852 .await
853 .trace_err()
854 {
855 if let Some(token) = live_kit
856 .room_token(&live_kit_room, &session.user_id.to_string())
857 .trace_err()
858 {
859 Some(proto::LiveKitConnectionInfo {
860 server_url: live_kit.url().into(),
861 token,
862 })
863 } else {
864 None
865 }
866 } else {
867 None
868 }
869 } else {
870 None
871 };
872
873 {
874 let room = session
875 .db()
876 .await
877 .create_room(session.user_id, session.connection_id, &live_kit_room)
878 .await?;
879
880 response.send(proto::CreateRoomResponse {
881 room: Some(room.clone()),
882 live_kit_connection_info,
883 })?;
884 }
885
886 update_user_contacts(session.user_id, &session).await?;
887 Ok(())
888}
889
890async fn join_room(
891 request: proto::JoinRoom,
892 response: Response<proto::JoinRoom>,
893 session: Session,
894) -> Result<()> {
895 let room_id = RoomId::from_proto(request.id);
896 let room = {
897 let room = session
898 .db()
899 .await
900 .join_room(room_id, session.user_id, session.connection_id)
901 .await?;
902 room_updated(&room, &session.peer);
903 room.clone()
904 };
905
906 for connection_id in session
907 .connection_pool()
908 .await
909 .user_connection_ids(session.user_id)
910 {
911 session
912 .peer
913 .send(
914 connection_id,
915 proto::CallCanceled {
916 room_id: room_id.to_proto(),
917 },
918 )
919 .trace_err();
920 }
921
922 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
923 if let Some(token) = live_kit
924 .room_token(&room.live_kit_room, &session.user_id.to_string())
925 .trace_err()
926 {
927 Some(proto::LiveKitConnectionInfo {
928 server_url: live_kit.url().into(),
929 token,
930 })
931 } else {
932 None
933 }
934 } else {
935 None
936 };
937
938 response.send(proto::JoinRoomResponse {
939 room: Some(room),
940 live_kit_connection_info,
941 })?;
942
943 update_user_contacts(session.user_id, &session).await?;
944 Ok(())
945}
946
947async fn rejoin_room(
948 request: proto::RejoinRoom,
949 response: Response<proto::RejoinRoom>,
950 session: Session,
951) -> Result<()> {
952 {
953 let mut rejoined_room = session
954 .db()
955 .await
956 .rejoin_room(request, session.user_id, session.connection_id)
957 .await?;
958
959 response.send(proto::RejoinRoomResponse {
960 room: Some(rejoined_room.room.clone()),
961 reshared_projects: rejoined_room
962 .reshared_projects
963 .iter()
964 .map(|project| proto::ResharedProject {
965 id: project.id.to_proto(),
966 collaborators: project
967 .collaborators
968 .iter()
969 .map(|collaborator| collaborator.to_proto())
970 .collect(),
971 })
972 .collect(),
973 rejoined_projects: rejoined_room
974 .rejoined_projects
975 .iter()
976 .map(|rejoined_project| proto::RejoinedProject {
977 id: rejoined_project.id.to_proto(),
978 worktrees: rejoined_project
979 .worktrees
980 .iter()
981 .map(|worktree| proto::WorktreeMetadata {
982 id: worktree.id,
983 root_name: worktree.root_name.clone(),
984 visible: worktree.visible,
985 abs_path: worktree.abs_path.clone(),
986 })
987 .collect(),
988 collaborators: rejoined_project
989 .collaborators
990 .iter()
991 .map(|collaborator| collaborator.to_proto())
992 .collect(),
993 language_servers: rejoined_project.language_servers.clone(),
994 })
995 .collect(),
996 })?;
997 room_updated(&rejoined_room.room, &session.peer);
998
999 for project in &rejoined_room.reshared_projects {
1000 for collaborator in &project.collaborators {
1001 session
1002 .peer
1003 .send(
1004 collaborator.connection_id,
1005 proto::UpdateProjectCollaborator {
1006 project_id: project.id.to_proto(),
1007 old_peer_id: Some(project.old_connection_id.into()),
1008 new_peer_id: Some(session.connection_id.into()),
1009 },
1010 )
1011 .trace_err();
1012 }
1013
1014 broadcast(
1015 Some(session.connection_id),
1016 project
1017 .collaborators
1018 .iter()
1019 .map(|collaborator| collaborator.connection_id),
1020 |connection_id| {
1021 session.peer.forward_send(
1022 session.connection_id,
1023 connection_id,
1024 proto::UpdateProject {
1025 project_id: project.id.to_proto(),
1026 worktrees: project.worktrees.clone(),
1027 },
1028 )
1029 },
1030 );
1031 }
1032
1033 for project in &rejoined_room.rejoined_projects {
1034 for collaborator in &project.collaborators {
1035 session
1036 .peer
1037 .send(
1038 collaborator.connection_id,
1039 proto::UpdateProjectCollaborator {
1040 project_id: project.id.to_proto(),
1041 old_peer_id: Some(project.old_connection_id.into()),
1042 new_peer_id: Some(session.connection_id.into()),
1043 },
1044 )
1045 .trace_err();
1046 }
1047 }
1048
1049 for project in &mut rejoined_room.rejoined_projects {
1050 for worktree in mem::take(&mut project.worktrees) {
1051 #[cfg(any(test, feature = "test-support"))]
1052 const MAX_CHUNK_SIZE: usize = 2;
1053 #[cfg(not(any(test, feature = "test-support")))]
1054 const MAX_CHUNK_SIZE: usize = 256;
1055
1056 // Stream this worktree's entries.
1057 let message = proto::UpdateWorktree {
1058 project_id: project.id.to_proto(),
1059 worktree_id: worktree.id,
1060 abs_path: worktree.abs_path.clone(),
1061 root_name: worktree.root_name,
1062 updated_entries: worktree.updated_entries,
1063 removed_entries: worktree.removed_entries,
1064 scan_id: worktree.scan_id,
1065 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1066 };
1067 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1068 session.peer.send(session.connection_id, update.clone())?;
1069 }
1070
1071 // Stream this worktree's diagnostics.
1072 for summary in worktree.diagnostic_summaries {
1073 session.peer.send(
1074 session.connection_id,
1075 proto::UpdateDiagnosticSummary {
1076 project_id: project.id.to_proto(),
1077 worktree_id: worktree.id,
1078 summary: Some(summary),
1079 },
1080 )?;
1081 }
1082 }
1083
1084 for language_server in &project.language_servers {
1085 session.peer.send(
1086 session.connection_id,
1087 proto::UpdateLanguageServer {
1088 project_id: project.id.to_proto(),
1089 language_server_id: language_server.id,
1090 variant: Some(
1091 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1092 proto::LspDiskBasedDiagnosticsUpdated {},
1093 ),
1094 ),
1095 },
1096 )?;
1097 }
1098 }
1099 }
1100
1101 update_user_contacts(session.user_id, &session).await?;
1102 Ok(())
1103}
1104
1105async fn leave_room(
1106 _: proto::LeaveRoom,
1107 response: Response<proto::LeaveRoom>,
1108 session: Session,
1109) -> Result<()> {
1110 leave_room_for_session(&session).await?;
1111 response.send(proto::Ack {})?;
1112 Ok(())
1113}
1114
1115async fn call(
1116 request: proto::Call,
1117 response: Response<proto::Call>,
1118 session: Session,
1119) -> Result<()> {
1120 let room_id = RoomId::from_proto(request.room_id);
1121 let calling_user_id = session.user_id;
1122 let calling_connection_id = session.connection_id;
1123 let called_user_id = UserId::from_proto(request.called_user_id);
1124 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1125 if !session
1126 .db()
1127 .await
1128 .has_contact(calling_user_id, called_user_id)
1129 .await?
1130 {
1131 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1132 }
1133
1134 let incoming_call = {
1135 let (room, incoming_call) = &mut *session
1136 .db()
1137 .await
1138 .call(
1139 room_id,
1140 calling_user_id,
1141 calling_connection_id,
1142 called_user_id,
1143 initial_project_id,
1144 )
1145 .await?;
1146 room_updated(&room, &session.peer);
1147 mem::take(incoming_call)
1148 };
1149 update_user_contacts(called_user_id, &session).await?;
1150
1151 let mut calls = session
1152 .connection_pool()
1153 .await
1154 .user_connection_ids(called_user_id)
1155 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1156 .collect::<FuturesUnordered<_>>();
1157
1158 while let Some(call_response) = calls.next().await {
1159 match call_response.as_ref() {
1160 Ok(_) => {
1161 response.send(proto::Ack {})?;
1162 return Ok(());
1163 }
1164 Err(_) => {
1165 call_response.trace_err();
1166 }
1167 }
1168 }
1169
1170 {
1171 let room = session
1172 .db()
1173 .await
1174 .call_failed(room_id, called_user_id)
1175 .await?;
1176 room_updated(&room, &session.peer);
1177 }
1178 update_user_contacts(called_user_id, &session).await?;
1179
1180 Err(anyhow!("failed to ring user"))?
1181}
1182
1183async fn cancel_call(
1184 request: proto::CancelCall,
1185 response: Response<proto::CancelCall>,
1186 session: Session,
1187) -> Result<()> {
1188 let called_user_id = UserId::from_proto(request.called_user_id);
1189 let room_id = RoomId::from_proto(request.room_id);
1190 {
1191 let room = session
1192 .db()
1193 .await
1194 .cancel_call(room_id, session.connection_id, called_user_id)
1195 .await?;
1196 room_updated(&room, &session.peer);
1197 }
1198
1199 for connection_id in session
1200 .connection_pool()
1201 .await
1202 .user_connection_ids(called_user_id)
1203 {
1204 session
1205 .peer
1206 .send(
1207 connection_id,
1208 proto::CallCanceled {
1209 room_id: room_id.to_proto(),
1210 },
1211 )
1212 .trace_err();
1213 }
1214 response.send(proto::Ack {})?;
1215
1216 update_user_contacts(called_user_id, &session).await?;
1217 Ok(())
1218}
1219
1220async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1221 let room_id = RoomId::from_proto(message.room_id);
1222 {
1223 let room = session
1224 .db()
1225 .await
1226 .decline_call(Some(room_id), session.user_id)
1227 .await?
1228 .ok_or_else(|| anyhow!("failed to decline call"))?;
1229 room_updated(&room, &session.peer);
1230 }
1231
1232 for connection_id in session
1233 .connection_pool()
1234 .await
1235 .user_connection_ids(session.user_id)
1236 {
1237 session
1238 .peer
1239 .send(
1240 connection_id,
1241 proto::CallCanceled {
1242 room_id: room_id.to_proto(),
1243 },
1244 )
1245 .trace_err();
1246 }
1247 update_user_contacts(session.user_id, &session).await?;
1248 Ok(())
1249}
1250
1251async fn update_participant_location(
1252 request: proto::UpdateParticipantLocation,
1253 response: Response<proto::UpdateParticipantLocation>,
1254 session: Session,
1255) -> Result<()> {
1256 let room_id = RoomId::from_proto(request.room_id);
1257 let location = request
1258 .location
1259 .ok_or_else(|| anyhow!("invalid location"))?;
1260 let room = session
1261 .db()
1262 .await
1263 .update_room_participant_location(room_id, session.connection_id, location)
1264 .await?;
1265 room_updated(&room, &session.peer);
1266 response.send(proto::Ack {})?;
1267 Ok(())
1268}
1269
1270async fn share_project(
1271 request: proto::ShareProject,
1272 response: Response<proto::ShareProject>,
1273 session: Session,
1274) -> Result<()> {
1275 let (project_id, room) = &*session
1276 .db()
1277 .await
1278 .share_project(
1279 RoomId::from_proto(request.room_id),
1280 session.connection_id,
1281 &request.worktrees,
1282 )
1283 .await?;
1284 response.send(proto::ShareProjectResponse {
1285 project_id: project_id.to_proto(),
1286 })?;
1287 room_updated(&room, &session.peer);
1288
1289 Ok(())
1290}
1291
1292async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1293 let project_id = ProjectId::from_proto(message.project_id);
1294
1295 let (room, guest_connection_ids) = &*session
1296 .db()
1297 .await
1298 .unshare_project(project_id, session.connection_id)
1299 .await?;
1300
1301 broadcast(
1302 Some(session.connection_id),
1303 guest_connection_ids.iter().copied(),
1304 |conn_id| session.peer.send(conn_id, message.clone()),
1305 );
1306 room_updated(&room, &session.peer);
1307
1308 Ok(())
1309}
1310
1311async fn join_project(
1312 request: proto::JoinProject,
1313 response: Response<proto::JoinProject>,
1314 session: Session,
1315) -> Result<()> {
1316 let project_id = ProjectId::from_proto(request.project_id);
1317 let guest_user_id = session.user_id;
1318
1319 tracing::info!(%project_id, "join project");
1320
1321 let (project, replica_id) = &mut *session
1322 .db()
1323 .await
1324 .join_project(project_id, session.connection_id)
1325 .await?;
1326
1327 let collaborators = project
1328 .collaborators
1329 .iter()
1330 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1331 .map(|collaborator| collaborator.to_proto())
1332 .collect::<Vec<_>>();
1333
1334 let worktrees = project
1335 .worktrees
1336 .iter()
1337 .map(|(id, worktree)| proto::WorktreeMetadata {
1338 id: *id,
1339 root_name: worktree.root_name.clone(),
1340 visible: worktree.visible,
1341 abs_path: worktree.abs_path.clone(),
1342 })
1343 .collect::<Vec<_>>();
1344
1345 for collaborator in &collaborators {
1346 session
1347 .peer
1348 .send(
1349 collaborator.peer_id.unwrap().into(),
1350 proto::AddProjectCollaborator {
1351 project_id: project_id.to_proto(),
1352 collaborator: Some(proto::Collaborator {
1353 peer_id: Some(session.connection_id.into()),
1354 replica_id: replica_id.0 as u32,
1355 user_id: guest_user_id.to_proto(),
1356 }),
1357 },
1358 )
1359 .trace_err();
1360 }
1361
1362 // First, we send the metadata associated with each worktree.
1363 response.send(proto::JoinProjectResponse {
1364 worktrees: worktrees.clone(),
1365 replica_id: replica_id.0 as u32,
1366 collaborators: collaborators.clone(),
1367 language_servers: project.language_servers.clone(),
1368 })?;
1369
1370 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1371 #[cfg(any(test, feature = "test-support"))]
1372 const MAX_CHUNK_SIZE: usize = 2;
1373 #[cfg(not(any(test, feature = "test-support")))]
1374 const MAX_CHUNK_SIZE: usize = 256;
1375
1376 // Stream this worktree's entries.
1377 let message = proto::UpdateWorktree {
1378 project_id: project_id.to_proto(),
1379 worktree_id,
1380 abs_path: worktree.abs_path.clone(),
1381 root_name: worktree.root_name,
1382 updated_entries: worktree.entries,
1383 removed_entries: Default::default(),
1384 scan_id: worktree.scan_id,
1385 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1386 };
1387 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1388 session.peer.send(session.connection_id, update.clone())?;
1389 }
1390
1391 // Stream this worktree's diagnostics.
1392 for summary in worktree.diagnostic_summaries {
1393 session.peer.send(
1394 session.connection_id,
1395 proto::UpdateDiagnosticSummary {
1396 project_id: project_id.to_proto(),
1397 worktree_id: worktree.id,
1398 summary: Some(summary),
1399 },
1400 )?;
1401 }
1402 }
1403
1404 for language_server in &project.language_servers {
1405 session.peer.send(
1406 session.connection_id,
1407 proto::UpdateLanguageServer {
1408 project_id: project_id.to_proto(),
1409 language_server_id: language_server.id,
1410 variant: Some(
1411 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1412 proto::LspDiskBasedDiagnosticsUpdated {},
1413 ),
1414 ),
1415 },
1416 )?;
1417 }
1418
1419 Ok(())
1420}
1421
1422async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1423 let sender_id = session.connection_id;
1424 let project_id = ProjectId::from_proto(request.project_id);
1425
1426 let (room, project) = &*session
1427 .db()
1428 .await
1429 .leave_project(project_id, sender_id)
1430 .await?;
1431 tracing::info!(
1432 %project_id,
1433 host_user_id = %project.host_user_id,
1434 host_connection_id = %project.host_connection_id,
1435 "leave project"
1436 );
1437
1438 project_left(&project, &session);
1439 room_updated(&room, &session.peer);
1440
1441 Ok(())
1442}
1443
1444async fn update_project(
1445 request: proto::UpdateProject,
1446 response: Response<proto::UpdateProject>,
1447 session: Session,
1448) -> Result<()> {
1449 let project_id = ProjectId::from_proto(request.project_id);
1450 let (room, guest_connection_ids) = &*session
1451 .db()
1452 .await
1453 .update_project(project_id, session.connection_id, &request.worktrees)
1454 .await?;
1455 broadcast(
1456 Some(session.connection_id),
1457 guest_connection_ids.iter().copied(),
1458 |connection_id| {
1459 session
1460 .peer
1461 .forward_send(session.connection_id, connection_id, request.clone())
1462 },
1463 );
1464 room_updated(&room, &session.peer);
1465 response.send(proto::Ack {})?;
1466
1467 Ok(())
1468}
1469
1470async fn update_worktree(
1471 request: proto::UpdateWorktree,
1472 response: Response<proto::UpdateWorktree>,
1473 session: Session,
1474) -> Result<()> {
1475 let guest_connection_ids = session
1476 .db()
1477 .await
1478 .update_worktree(&request, session.connection_id)
1479 .await?;
1480
1481 broadcast(
1482 Some(session.connection_id),
1483 guest_connection_ids.iter().copied(),
1484 |connection_id| {
1485 session
1486 .peer
1487 .forward_send(session.connection_id, connection_id, request.clone())
1488 },
1489 );
1490 response.send(proto::Ack {})?;
1491 Ok(())
1492}
1493
1494async fn update_diagnostic_summary(
1495 message: proto::UpdateDiagnosticSummary,
1496 session: Session,
1497) -> Result<()> {
1498 let guest_connection_ids = session
1499 .db()
1500 .await
1501 .update_diagnostic_summary(&message, session.connection_id)
1502 .await?;
1503
1504 broadcast(
1505 Some(session.connection_id),
1506 guest_connection_ids.iter().copied(),
1507 |connection_id| {
1508 session
1509 .peer
1510 .forward_send(session.connection_id, connection_id, message.clone())
1511 },
1512 );
1513
1514 Ok(())
1515}
1516
1517async fn start_language_server(
1518 request: proto::StartLanguageServer,
1519 session: Session,
1520) -> Result<()> {
1521 let guest_connection_ids = session
1522 .db()
1523 .await
1524 .start_language_server(&request, session.connection_id)
1525 .await?;
1526
1527 broadcast(
1528 Some(session.connection_id),
1529 guest_connection_ids.iter().copied(),
1530 |connection_id| {
1531 session
1532 .peer
1533 .forward_send(session.connection_id, connection_id, request.clone())
1534 },
1535 );
1536 Ok(())
1537}
1538
1539async fn update_language_server(
1540 request: proto::UpdateLanguageServer,
1541 session: Session,
1542) -> Result<()> {
1543 session.executor.record_backtrace();
1544 let project_id = ProjectId::from_proto(request.project_id);
1545 let project_connection_ids = session
1546 .db()
1547 .await
1548 .project_connection_ids(project_id, session.connection_id)
1549 .await?;
1550 broadcast(
1551 Some(session.connection_id),
1552 project_connection_ids.iter().copied(),
1553 |connection_id| {
1554 session
1555 .peer
1556 .forward_send(session.connection_id, connection_id, request.clone())
1557 },
1558 );
1559 Ok(())
1560}
1561
1562async fn forward_project_request<T>(
1563 request: T,
1564 response: Response<T>,
1565 session: Session,
1566) -> Result<()>
1567where
1568 T: EntityMessage + RequestMessage,
1569{
1570 session.executor.record_backtrace();
1571 let project_id = ProjectId::from_proto(request.remote_entity_id());
1572 let host_connection_id = {
1573 let collaborators = session
1574 .db()
1575 .await
1576 .project_collaborators(project_id, session.connection_id)
1577 .await?;
1578 collaborators
1579 .iter()
1580 .find(|collaborator| collaborator.is_host)
1581 .ok_or_else(|| anyhow!("host not found"))?
1582 .connection_id
1583 };
1584
1585 let payload = session
1586 .peer
1587 .forward_request(session.connection_id, host_connection_id, request)
1588 .await?;
1589
1590 response.send(payload)?;
1591 Ok(())
1592}
1593
1594async fn save_buffer(
1595 request: proto::SaveBuffer,
1596 response: Response<proto::SaveBuffer>,
1597 session: Session,
1598) -> Result<()> {
1599 let project_id = ProjectId::from_proto(request.project_id);
1600 let host_connection_id = {
1601 let collaborators = session
1602 .db()
1603 .await
1604 .project_collaborators(project_id, session.connection_id)
1605 .await?;
1606 collaborators
1607 .iter()
1608 .find(|collaborator| collaborator.is_host)
1609 .ok_or_else(|| anyhow!("host not found"))?
1610 .connection_id
1611 };
1612 let response_payload = session
1613 .peer
1614 .forward_request(session.connection_id, host_connection_id, request.clone())
1615 .await?;
1616
1617 let mut collaborators = session
1618 .db()
1619 .await
1620 .project_collaborators(project_id, session.connection_id)
1621 .await?;
1622 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1623 let project_connection_ids = collaborators
1624 .iter()
1625 .map(|collaborator| collaborator.connection_id);
1626 broadcast(
1627 Some(host_connection_id),
1628 project_connection_ids,
1629 |conn_id| {
1630 session
1631 .peer
1632 .forward_send(host_connection_id, conn_id, response_payload.clone())
1633 },
1634 );
1635 response.send(response_payload)?;
1636 Ok(())
1637}
1638
1639async fn create_buffer_for_peer(
1640 request: proto::CreateBufferForPeer,
1641 session: Session,
1642) -> Result<()> {
1643 session.executor.record_backtrace();
1644 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1645 session
1646 .peer
1647 .forward_send(session.connection_id, peer_id.into(), request)?;
1648 Ok(())
1649}
1650
1651async fn update_buffer(
1652 request: proto::UpdateBuffer,
1653 response: Response<proto::UpdateBuffer>,
1654 session: Session,
1655) -> Result<()> {
1656 session.executor.record_backtrace();
1657 let project_id = ProjectId::from_proto(request.project_id);
1658 let host_connection_id = {
1659 let collaborators = session
1660 .db()
1661 .await
1662 .project_collaborators(project_id, session.connection_id)
1663 .await?;
1664
1665 let host = collaborators
1666 .iter()
1667 .find(|collaborator| collaborator.is_host)
1668 .ok_or_else(|| anyhow!("host not found"))?;
1669 host.connection_id
1670 };
1671
1672 if host_connection_id != session.connection_id {
1673 session
1674 .peer
1675 .forward_request(session.connection_id, host_connection_id, request.clone())
1676 .await?;
1677 }
1678
1679 session.executor.record_backtrace();
1680 let collaborators = session
1681 .db()
1682 .await
1683 .project_collaborators(project_id, session.connection_id)
1684 .await?;
1685
1686 broadcast(
1687 Some(session.connection_id),
1688 collaborators
1689 .iter()
1690 .filter(|collaborator| !collaborator.is_host)
1691 .map(|collaborator| collaborator.connection_id),
1692 |connection_id| {
1693 session
1694 .peer
1695 .forward_send(session.connection_id, connection_id, request.clone())
1696 },
1697 );
1698 response.send(proto::Ack {})?;
1699 Ok(())
1700}
1701
1702async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1703 let project_id = ProjectId::from_proto(request.project_id);
1704 let project_connection_ids = session
1705 .db()
1706 .await
1707 .project_connection_ids(project_id, session.connection_id)
1708 .await?;
1709
1710 broadcast(
1711 Some(session.connection_id),
1712 project_connection_ids.iter().copied(),
1713 |connection_id| {
1714 session
1715 .peer
1716 .forward_send(session.connection_id, connection_id, request.clone())
1717 },
1718 );
1719 Ok(())
1720}
1721
1722async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1723 let project_id = ProjectId::from_proto(request.project_id);
1724 let project_connection_ids = session
1725 .db()
1726 .await
1727 .project_connection_ids(project_id, session.connection_id)
1728 .await?;
1729 broadcast(
1730 Some(session.connection_id),
1731 project_connection_ids.iter().copied(),
1732 |connection_id| {
1733 session
1734 .peer
1735 .forward_send(session.connection_id, connection_id, request.clone())
1736 },
1737 );
1738 Ok(())
1739}
1740
1741async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1742 let project_id = ProjectId::from_proto(request.project_id);
1743 let project_connection_ids = session
1744 .db()
1745 .await
1746 .project_connection_ids(project_id, session.connection_id)
1747 .await?;
1748 broadcast(
1749 Some(session.connection_id),
1750 project_connection_ids.iter().copied(),
1751 |connection_id| {
1752 session
1753 .peer
1754 .forward_send(session.connection_id, connection_id, request.clone())
1755 },
1756 );
1757 Ok(())
1758}
1759
1760async fn follow(
1761 request: proto::Follow,
1762 response: Response<proto::Follow>,
1763 session: Session,
1764) -> Result<()> {
1765 let project_id = ProjectId::from_proto(request.project_id);
1766 let leader_id = request
1767 .leader_id
1768 .ok_or_else(|| anyhow!("invalid leader id"))?
1769 .into();
1770 let follower_id = session.connection_id;
1771
1772 {
1773 let project_connection_ids = session
1774 .db()
1775 .await
1776 .project_connection_ids(project_id, session.connection_id)
1777 .await?;
1778
1779 if !project_connection_ids.contains(&leader_id) {
1780 Err(anyhow!("no such peer"))?;
1781 }
1782 }
1783
1784 let mut response_payload = session
1785 .peer
1786 .forward_request(session.connection_id, leader_id, request)
1787 .await?;
1788 response_payload
1789 .views
1790 .retain(|view| view.leader_id != Some(follower_id.into()));
1791 response.send(response_payload)?;
1792
1793 let room = session
1794 .db()
1795 .await
1796 .follow(project_id, leader_id, follower_id)
1797 .await?;
1798 room_updated(&room, &session.peer);
1799
1800 Ok(())
1801}
1802
1803async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1804 let project_id = ProjectId::from_proto(request.project_id);
1805 let leader_id = request
1806 .leader_id
1807 .ok_or_else(|| anyhow!("invalid leader id"))?
1808 .into();
1809 let follower_id = session.connection_id;
1810
1811 if !session
1812 .db()
1813 .await
1814 .project_connection_ids(project_id, session.connection_id)
1815 .await?
1816 .contains(&leader_id)
1817 {
1818 Err(anyhow!("no such peer"))?;
1819 }
1820
1821 session
1822 .peer
1823 .forward_send(session.connection_id, leader_id, request)?;
1824
1825 let room = session
1826 .db()
1827 .await
1828 .unfollow(project_id, leader_id, follower_id)
1829 .await?;
1830 room_updated(&room, &session.peer);
1831
1832 Ok(())
1833}
1834
1835async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1836 let project_id = ProjectId::from_proto(request.project_id);
1837 let project_connection_ids = session
1838 .db
1839 .lock()
1840 .await
1841 .project_connection_ids(project_id, session.connection_id)
1842 .await?;
1843
1844 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1845 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1846 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1847 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1848 });
1849 for follower_peer_id in request.follower_ids.iter().copied() {
1850 let follower_connection_id = follower_peer_id.into();
1851 if project_connection_ids.contains(&follower_connection_id)
1852 && Some(follower_peer_id) != leader_id
1853 {
1854 session.peer.forward_send(
1855 session.connection_id,
1856 follower_connection_id,
1857 request.clone(),
1858 )?;
1859 }
1860 }
1861 Ok(())
1862}
1863
1864async fn get_users(
1865 request: proto::GetUsers,
1866 response: Response<proto::GetUsers>,
1867 session: Session,
1868) -> Result<()> {
1869 let user_ids = request
1870 .user_ids
1871 .into_iter()
1872 .map(UserId::from_proto)
1873 .collect();
1874 let users = session
1875 .db()
1876 .await
1877 .get_users_by_ids(user_ids)
1878 .await?
1879 .into_iter()
1880 .map(|user| proto::User {
1881 id: user.id.to_proto(),
1882 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1883 github_login: user.github_login,
1884 })
1885 .collect();
1886 response.send(proto::UsersResponse { users })?;
1887 Ok(())
1888}
1889
1890async fn fuzzy_search_users(
1891 request: proto::FuzzySearchUsers,
1892 response: Response<proto::FuzzySearchUsers>,
1893 session: Session,
1894) -> Result<()> {
1895 let query = request.query;
1896 let users = match query.len() {
1897 0 => vec![],
1898 1 | 2 => session
1899 .db()
1900 .await
1901 .get_user_by_github_login(&query)
1902 .await?
1903 .into_iter()
1904 .collect(),
1905 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1906 };
1907 let users = users
1908 .into_iter()
1909 .filter(|user| user.id != session.user_id)
1910 .map(|user| proto::User {
1911 id: user.id.to_proto(),
1912 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1913 github_login: user.github_login,
1914 })
1915 .collect();
1916 response.send(proto::UsersResponse { users })?;
1917 Ok(())
1918}
1919
1920async fn request_contact(
1921 request: proto::RequestContact,
1922 response: Response<proto::RequestContact>,
1923 session: Session,
1924) -> Result<()> {
1925 let requester_id = session.user_id;
1926 let responder_id = UserId::from_proto(request.responder_id);
1927 if requester_id == responder_id {
1928 return Err(anyhow!("cannot add yourself as a contact"))?;
1929 }
1930
1931 session
1932 .db()
1933 .await
1934 .send_contact_request(requester_id, responder_id)
1935 .await?;
1936
1937 // Update outgoing contact requests of requester
1938 let mut update = proto::UpdateContacts::default();
1939 update.outgoing_requests.push(responder_id.to_proto());
1940 for connection_id in session
1941 .connection_pool()
1942 .await
1943 .user_connection_ids(requester_id)
1944 {
1945 session.peer.send(connection_id, update.clone())?;
1946 }
1947
1948 // Update incoming contact requests of responder
1949 let mut update = proto::UpdateContacts::default();
1950 update
1951 .incoming_requests
1952 .push(proto::IncomingContactRequest {
1953 requester_id: requester_id.to_proto(),
1954 should_notify: true,
1955 });
1956 for connection_id in session
1957 .connection_pool()
1958 .await
1959 .user_connection_ids(responder_id)
1960 {
1961 session.peer.send(connection_id, update.clone())?;
1962 }
1963
1964 response.send(proto::Ack {})?;
1965 Ok(())
1966}
1967
1968async fn respond_to_contact_request(
1969 request: proto::RespondToContactRequest,
1970 response: Response<proto::RespondToContactRequest>,
1971 session: Session,
1972) -> Result<()> {
1973 let responder_id = session.user_id;
1974 let requester_id = UserId::from_proto(request.requester_id);
1975 let db = session.db().await;
1976 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1977 db.dismiss_contact_notification(responder_id, requester_id)
1978 .await?;
1979 } else {
1980 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1981
1982 db.respond_to_contact_request(responder_id, requester_id, accept)
1983 .await?;
1984 let requester_busy = db.is_user_busy(requester_id).await?;
1985 let responder_busy = db.is_user_busy(responder_id).await?;
1986
1987 let pool = session.connection_pool().await;
1988 // Update responder with new contact
1989 let mut update = proto::UpdateContacts::default();
1990 if accept {
1991 update
1992 .contacts
1993 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1994 }
1995 update
1996 .remove_incoming_requests
1997 .push(requester_id.to_proto());
1998 for connection_id in pool.user_connection_ids(responder_id) {
1999 session.peer.send(connection_id, update.clone())?;
2000 }
2001
2002 // Update requester with new contact
2003 let mut update = proto::UpdateContacts::default();
2004 if accept {
2005 update
2006 .contacts
2007 .push(contact_for_user(responder_id, true, responder_busy, &pool));
2008 }
2009 update
2010 .remove_outgoing_requests
2011 .push(responder_id.to_proto());
2012 for connection_id in pool.user_connection_ids(requester_id) {
2013 session.peer.send(connection_id, update.clone())?;
2014 }
2015 }
2016
2017 response.send(proto::Ack {})?;
2018 Ok(())
2019}
2020
2021async fn remove_contact(
2022 request: proto::RemoveContact,
2023 response: Response<proto::RemoveContact>,
2024 session: Session,
2025) -> Result<()> {
2026 let requester_id = session.user_id;
2027 let responder_id = UserId::from_proto(request.user_id);
2028 let db = session.db().await;
2029 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
2030
2031 let pool = session.connection_pool().await;
2032 // Update outgoing contact requests of requester
2033 let mut update = proto::UpdateContacts::default();
2034 if contact_accepted {
2035 update.remove_contacts.push(responder_id.to_proto());
2036 } else {
2037 update
2038 .remove_outgoing_requests
2039 .push(responder_id.to_proto());
2040 }
2041 for connection_id in pool.user_connection_ids(requester_id) {
2042 session.peer.send(connection_id, update.clone())?;
2043 }
2044
2045 // Update incoming contact requests of responder
2046 let mut update = proto::UpdateContacts::default();
2047 if contact_accepted {
2048 update.remove_contacts.push(requester_id.to_proto());
2049 } else {
2050 update
2051 .remove_incoming_requests
2052 .push(requester_id.to_proto());
2053 }
2054 for connection_id in pool.user_connection_ids(responder_id) {
2055 session.peer.send(connection_id, update.clone())?;
2056 }
2057
2058 response.send(proto::Ack {})?;
2059 Ok(())
2060}
2061
2062async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2063 let project_id = ProjectId::from_proto(request.project_id);
2064 let project_connection_ids = session
2065 .db()
2066 .await
2067 .project_connection_ids(project_id, session.connection_id)
2068 .await?;
2069 broadcast(
2070 Some(session.connection_id),
2071 project_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, request.clone())
2076 },
2077 );
2078 Ok(())
2079}
2080
2081async fn get_private_user_info(
2082 _request: proto::GetPrivateUserInfo,
2083 response: Response<proto::GetPrivateUserInfo>,
2084 session: Session,
2085) -> Result<()> {
2086 let metrics_id = session
2087 .db()
2088 .await
2089 .get_user_metrics_id(session.user_id)
2090 .await?;
2091 let user = session
2092 .db()
2093 .await
2094 .get_user_by_id(session.user_id)
2095 .await?
2096 .ok_or_else(|| anyhow!("user not found"))?;
2097 response.send(proto::GetPrivateUserInfoResponse {
2098 metrics_id,
2099 staff: user.admin,
2100 })?;
2101 Ok(())
2102}
2103
2104fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2105 match message {
2106 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2107 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2108 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2109 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2110 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2111 code: frame.code.into(),
2112 reason: frame.reason,
2113 })),
2114 }
2115}
2116
2117fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2118 match message {
2119 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2120 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2121 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2122 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2123 AxumMessage::Close(frame) => {
2124 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2125 code: frame.code.into(),
2126 reason: frame.reason,
2127 }))
2128 }
2129 }
2130}
2131
2132fn build_initial_contacts_update(
2133 contacts: Vec<db::Contact>,
2134 pool: &ConnectionPool,
2135) -> proto::UpdateContacts {
2136 let mut update = proto::UpdateContacts::default();
2137
2138 for contact in contacts {
2139 match contact {
2140 db::Contact::Accepted {
2141 user_id,
2142 should_notify,
2143 busy,
2144 } => {
2145 update
2146 .contacts
2147 .push(contact_for_user(user_id, should_notify, busy, &pool));
2148 }
2149 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2150 db::Contact::Incoming {
2151 user_id,
2152 should_notify,
2153 } => update
2154 .incoming_requests
2155 .push(proto::IncomingContactRequest {
2156 requester_id: user_id.to_proto(),
2157 should_notify,
2158 }),
2159 }
2160 }
2161
2162 update
2163}
2164
2165fn contact_for_user(
2166 user_id: UserId,
2167 should_notify: bool,
2168 busy: bool,
2169 pool: &ConnectionPool,
2170) -> proto::Contact {
2171 proto::Contact {
2172 user_id: user_id.to_proto(),
2173 online: pool.is_user_online(user_id),
2174 busy,
2175 should_notify,
2176 }
2177}
2178
2179fn room_updated(room: &proto::Room, peer: &Peer) {
2180 broadcast(
2181 None,
2182 room.participants
2183 .iter()
2184 .filter_map(|participant| Some(participant.peer_id?.into())),
2185 |peer_id| {
2186 peer.send(
2187 peer_id.into(),
2188 proto::RoomUpdated {
2189 room: Some(room.clone()),
2190 },
2191 )
2192 },
2193 );
2194}
2195
2196async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2197 let db = session.db().await;
2198 let contacts = db.get_contacts(user_id).await?;
2199 let busy = db.is_user_busy(user_id).await?;
2200
2201 let pool = session.connection_pool().await;
2202 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2203 for contact in contacts {
2204 if let db::Contact::Accepted {
2205 user_id: contact_user_id,
2206 ..
2207 } = contact
2208 {
2209 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2210 session
2211 .peer
2212 .send(
2213 contact_conn_id,
2214 proto::UpdateContacts {
2215 contacts: vec![updated_contact.clone()],
2216 remove_contacts: Default::default(),
2217 incoming_requests: Default::default(),
2218 remove_incoming_requests: Default::default(),
2219 outgoing_requests: Default::default(),
2220 remove_outgoing_requests: Default::default(),
2221 },
2222 )
2223 .trace_err();
2224 }
2225 }
2226 }
2227 Ok(())
2228}
2229
/// Removes `session`'s connection from the room it is in and performs the
/// follow-up notifications.
///
/// Steps:
/// 1. Asks the database to remove the connection. If it returns `None`,
///    returns `Ok(())` immediately.
/// 2. Notifies every project the connection left (`project_left`) and
///    broadcasts the updated room.
/// 3. Sends `CallCanceled` to every connection of each user whose pending
///    call the database reports as canceled.
/// 4. Refreshes contact presence for the leaving user and for each of those
///    users.
/// 5. If a LiveKit client is configured, removes the user (identified by
///    `session.user_id.to_string()`) from the LiveKit room. It also deletes
///    that room when the database room has no participants left. LiveKit
///    errors are traced, not propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // User ids whose contact presence must be re-broadcast at the end.
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; the `else` branch
    // returns, so they are always initialised afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // NOTE: temporaries created in this `if let` scrutinee, including the value
    // returned by `session.db().await`, live until the end of the `if let`
    // statement, as do `left_room` itself and anything it holds. Whatever they
    // guard therefore stays held while projects and the room are notified.
    // Keep this in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // No participants remain in the database room, so the LiveKit room
        // can be deleted as well (see the end of this function).
        delete_live_kit_room = left_room.room.participants.is_empty();
    } else {
        return Ok(());
    }

    // Scoped so the connection-pool guard is released before contacts are
    // updated (`update_user_contacts` acquires the pool itself).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are only traced.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2288
2289fn project_left(project: &db::LeftProject, session: &Session) {
2290 for connection_id in &project.connection_ids {
2291 if project.host_user_id == session.user_id {
2292 session
2293 .peer
2294 .send(
2295 *connection_id,
2296 proto::UnshareProject {
2297 project_id: project.id.to_proto(),
2298 },
2299 )
2300 .trace_err();
2301 } else {
2302 session
2303 .peer
2304 .send(
2305 *connection_id,
2306 proto::RemoveProjectCollaborator {
2307 project_id: project.id.to_proto(),
2308 peer_id: Some(session.connection_id.into()),
2309 },
2310 )
2311 .trace_err();
2312 }
2313 }
2314}
2315
2316pub trait ResultExt {
2317 type Ok;
2318
2319 fn trace_err(self) -> Option<Self::Ok>;
2320}
2321
2322impl<T, E> ResultExt for Result<T, E>
2323where
2324 E: std::fmt::Debug,
2325{
2326 type Ok = T;
2327
2328 fn trace_err(self) -> Option<T> {
2329 match self {
2330 Ok(value) => Some(value),
2331 Err(error) => {
2332 tracing::error!("{:?}", error);
2333 None
2334 }
2335 }
2336 }
2337}