1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` sleeps after a connection is lost before it calls
/// `leave_room_for_session` and refreshes the user's contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// How long the background task spawned by `Server::start` sleeps before it
/// retrieves stale rooms and deletes stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
62
lazy_static! {
    // Gauge registered as "connections"; `handle_metrics` sets it to the number of
    // non-admin connections in the pool. Registration failure panics on first use.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge registered as "shared_projects"; `handle_metrics` sets it from
    // `project_count_excluding_admins`. Registration failure panics on first use.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
/// Type-erased message handler stored in `Server::handlers`: it receives a boxed
/// envelope together with the `Session` it arrived on and returns a boxed future
/// that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
/// Handle passed to request handlers for answering a single request of type `R`.
struct Response<R> {
    /// Peer used to deliver the response.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Shared flag set by `send`; `add_request_handler` reads it to detect
    /// handlers that returned `Ok` without responding.
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Id of the user that owns this connection.
    user_id: UserId,
    /// Id the peer assigned to this connection in `handle_connection`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// The server's connection pool; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional LiveKit API client; handlers skip LiveKit work when it is `None`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
/// Newtype around the shared `Database`; dereferences to `Database`.
struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// RPC server state: the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and sends/receives messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel signalled by `teardown`; connections subscribe to it.
    teardown: watch::Sender<()>,
}
149
/// Wrapper around the connection pool's mutex guard that dereferences to
/// `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `PhantomData<Rc<()>>` makes this type `!Send`.
    // NOTE(review): presumably this keeps the guard from being held across an
    // `.await` in a `Send` future — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
154
/// Serializable view of the server built by `Server::snapshot`; it keeps the
/// connection pool locked for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// Borrowed peer.
    peer: &'a Peer,
    /// Locked connection pool, serialized through its `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
    /// Builds a server for `id`, registers every request and message handler, and
    /// returns it wrapped in an `Arc`.
    ///
    /// # Panics
    ///
    /// Panics (inside `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe`.
            teardown: watch::channel(()).0,
        };

        // `add_request_handler` is used for messages that expect a response,
        // `add_message_handler` for fire-and-forget messages.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_message_handler(leave_room)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_request_handler(forward_project_request::<proto::GetHover>)
            .add_request_handler(forward_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_project_request::<proto::GetReferences>)
            .add_request_handler(forward_project_request::<proto::SearchProject>)
            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_project_request::<proto::GetCompletions>)
            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_project_request::<proto::PerformRename>)
            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(update_buffer_file)
            .add_message_handler(buffer_reloaded)
            .add_message_handler(buffer_saved)
            .add_request_handler(save_buffer)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(update_diff_base)
            .add_request_handler(get_private_user_info);

        Arc::new(server)
    }
242
243 pub async fn start(&self) -> Result<()> {
244 let server_id = *self.id.lock();
245 let app_state = self.app_state.clone();
246 let peer = self.peer.clone();
247 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
248 let pool = self.connection_pool.clone();
249 let live_kit_client = self.app_state.live_kit_client.clone();
250
251 let span = info_span!("start server");
252 let span_enter = span.enter();
253
254 tracing::info!("begin deleting stale projects");
255 app_state
256 .db
257 .delete_stale_projects(&app_state.config.zed_environment, server_id)
258 .await?;
259 tracing::info!("finish deleting stale projects");
260
261 drop(span_enter);
262 self.executor.spawn_detached(
263 async move {
264 tracing::info!("waiting for cleanup timeout");
265 timeout.await;
266 tracing::info!("cleanup timeout expired, retrieving stale rooms");
267 if let Some(room_ids) = app_state
268 .db
269 .stale_room_ids(&app_state.config.zed_environment, server_id)
270 .await
271 .trace_err()
272 {
273 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
274 for room_id in room_ids {
275 let mut contacts_to_update = HashSet::default();
276 let mut canceled_calls_to_user_ids = Vec::new();
277 let mut live_kit_room = String::new();
278 let mut delete_live_kit_room = false;
279
280 if let Ok(mut refreshed_room) =
281 app_state.db.refresh_room(room_id, server_id).await
282 {
283 tracing::info!(
284 room_id = room_id.0,
285 new_participant_count = refreshed_room.room.participants.len(),
286 "refreshed room"
287 );
288 room_updated(&refreshed_room.room, &peer);
289 contacts_to_update
290 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
291 contacts_to_update
292 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
293 canceled_calls_to_user_ids =
294 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
295 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
296 delete_live_kit_room = refreshed_room.room.participants.is_empty();
297 }
298
299 {
300 let pool = pool.lock();
301 for canceled_user_id in canceled_calls_to_user_ids {
302 for connection_id in pool.user_connection_ids(canceled_user_id) {
303 peer.send(
304 connection_id,
305 proto::CallCanceled {
306 room_id: room_id.to_proto(),
307 },
308 )
309 .trace_err();
310 }
311 }
312 }
313
314 for user_id in contacts_to_update {
315 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
316 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
317 if let Some((busy, contacts)) = busy.zip(contacts) {
318 let pool = pool.lock();
319 let updated_contact = contact_for_user(user_id, false, busy, &pool);
320 for contact in contacts {
321 if let db::Contact::Accepted {
322 user_id: contact_user_id,
323 ..
324 } = contact
325 {
326 for contact_conn_id in
327 pool.user_connection_ids(contact_user_id)
328 {
329 peer.send(
330 contact_conn_id,
331 proto::UpdateContacts {
332 contacts: vec![updated_contact.clone()],
333 remove_contacts: Default::default(),
334 incoming_requests: Default::default(),
335 remove_incoming_requests: Default::default(),
336 outgoing_requests: Default::default(),
337 remove_outgoing_requests: Default::default(),
338 },
339 )
340 .trace_err();
341 }
342 }
343 }
344 }
345 }
346
347 if let Some(live_kit) = live_kit_client.as_ref() {
348 if delete_live_kit_room {
349 live_kit.delete_room(live_kit_room).await.trace_err();
350 }
351 }
352 }
353 }
354
355 app_state
356 .db
357 .delete_stale_servers(server_id, &app_state.config.zed_environment)
358 .await
359 .trace_err();
360 }
361 .instrument(span),
362 );
363 Ok(())
364 }
365
    /// Tears down the peer, resets the connection pool, and then signals the
    /// teardown watch channel. The result of the signal is deliberately ignored.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(());
    }
371
    /// Test-only: tears the server down, then replaces its id with `id` and resets
    /// the peer with the same id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }
378
379 #[cfg(test)]
380 pub fn id(&self) -> ServerId {
381 *self.id.lock()
382 }
383
384 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
385 where
386 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
387 Fut: 'static + Send + Future<Output = Result<()>>,
388 M: EnvelopedMessage,
389 {
390 let prev_handler = self.handlers.insert(
391 TypeId::of::<M>(),
392 Box::new(move |envelope, session| {
393 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
394 let span = info_span!(
395 "handle message",
396 payload_type = envelope.payload_type_name()
397 );
398 span.in_scope(|| {
399 tracing::info!(
400 payload_type = envelope.payload_type_name(),
401 "message received"
402 );
403 });
404 let future = (handler)(*envelope, session);
405 async move {
406 if let Err(error) = future.await {
407 tracing::error!(%error, "error handling message");
408 }
409 }
410 .instrument(span)
411 .boxed()
412 }),
413 );
414 if prev_handler.is_some() {
415 panic!("registered a handler for the same message twice");
416 }
417 self
418 }
419
420 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
421 where
422 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
423 Fut: 'static + Send + Future<Output = Result<()>>,
424 M: EnvelopedMessage,
425 {
426 self.add_handler(move |envelope, session| handler(envelope.payload, session));
427 self
428 }
429
430 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
431 where
432 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
433 Fut: Send + Future<Output = Result<()>>,
434 M: RequestMessage,
435 {
436 let handler = Arc::new(handler);
437 self.add_handler(move |envelope, session| {
438 let receipt = envelope.receipt();
439 let handler = handler.clone();
440 async move {
441 let peer = session.peer.clone();
442 let responded = Arc::new(AtomicBool::default());
443 let response = Response {
444 peer: peer.clone(),
445 responded: responded.clone(),
446 receipt,
447 };
448 match (handler)(envelope.payload, response, session).await {
449 Ok(()) => {
450 if responded.load(std::sync::atomic::Ordering::SeqCst) {
451 Ok(())
452 } else {
453 Err(anyhow!("handler did not send a response"))?
454 }
455 }
456 Err(error) => {
457 peer.respond_with_error(
458 receipt,
459 proto::Error {
460 message: error.to_string(),
461 },
462 )?;
463 Err(error)
464 }
465 }
466 }
467 })
468 }
469
    /// Returns a future that drives a single client connection from start to finish.
    ///
    /// The future registers `connection` with the peer, sends the `Hello` message,
    /// optionally reports the new connection id through `send_connection_id`, sends
    /// the initial contacts / invite info / pending incoming call, and then
    /// dispatches incoming messages to the registered handlers until the I/O task
    /// finishes or the incoming stream ends, after which it signs the session out.
    /// If teardown is signalled while the loop is running, the future returns
    /// `Ok(())` immediately without signing out.
    ///
    /// Errors from the setup steps are returned; errors from I/O handling and from
    /// sign-out are only logged.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        // Subscribe before the future is polled so a teardown sent afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiver may already be gone; that is not an error for the connection.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            // First-ever connection for this user: ask the client to show contacts
            // and record that the user has connected once.
            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Scope the pool lock so it is released before the next await.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written: teardown, I/O,
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = sign_out(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
596
597 pub async fn invite_code_redeemed(
598 self: &Arc<Self>,
599 inviter_id: UserId,
600 invitee_id: UserId,
601 ) -> Result<()> {
602 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
603 if let Some(code) = &user.invite_code {
604 let pool = self.connection_pool.lock();
605 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
606 for connection_id in pool.user_connection_ids(inviter_id) {
607 self.peer.send(
608 connection_id,
609 proto::UpdateContacts {
610 contacts: vec![invitee_contact.clone()],
611 ..Default::default()
612 },
613 )?;
614 self.peer.send(
615 connection_id,
616 proto::UpdateInviteInfo {
617 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
618 count: user.invite_count as u32,
619 },
620 )?;
621 }
622 }
623 }
624 Ok(())
625 }
626
627 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
628 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
629 if let Some(invite_code) = &user.invite_code {
630 let pool = self.connection_pool.lock();
631 for connection_id in pool.user_connection_ids(user_id) {
632 self.peer.send(
633 connection_id,
634 proto::UpdateInviteInfo {
635 url: format!(
636 "{}{}",
637 self.app_state.config.invite_link_prefix, invite_code
638 ),
639 count: user.invite_count as u32,
640 },
641 )?;
642 }
643 }
644 }
645 Ok(())
646 }
647
648 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
649 ServerSnapshot {
650 connection_pool: ConnectionPoolGuard {
651 guard: self.connection_pool.lock(),
652 _not_send: PhantomData,
653 },
654 peer: &self.peer,
655 }
656 }
657}
658
659impl<'a> Deref for ConnectionPoolGuard<'a> {
660 type Target = ConnectionPool;
661
662 fn deref(&self) -> &Self::Target {
663 &*self.guard
664 }
665}
666
667impl<'a> DerefMut for ConnectionPoolGuard<'a> {
668 fn deref_mut(&mut self) -> &mut Self::Target {
669 &mut *self.guard
670 }
671}
672
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` whenever a guard is released;
    /// in other builds this is a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
679
680fn broadcast<F>(
681 sender_id: ConnectionId,
682 receiver_ids: impl IntoIterator<Item = ConnectionId>,
683 mut f: F,
684) where
685 F: FnMut(ConnectionId) -> anyhow::Result<()>,
686{
687 for receiver_id in receiver_ids {
688 if receiver_id != sender_id {
689 f(receiver_id).trace_err();
690 }
691 }
692}
693
lazy_static! {
    // Name of the request header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
697
/// Typed `x-zed-protocol-version` header carrying the client's protocol version.
pub struct ProtocolVersion(u32);
699
700impl Header for ProtocolVersion {
701 fn name() -> &'static HeaderName {
702 &ZED_PROTOCOL_VERSION
703 }
704
705 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
706 where
707 Self: Sized,
708 I: Iterator<Item = &'i axum::http::HeaderValue>,
709 {
710 let version = values
711 .next()
712 .ok_or_else(axum::headers::Error::invalid)?
713 .to_str()
714 .map_err(|_| axum::headers::Error::invalid())?
715 .parse()
716 .map_err(|_| axum::headers::Error::invalid())?;
717 Ok(Self(version))
718 }
719
720 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
721 values.extend([self.0.to_string().parse().unwrap()]);
722 }
723}
724
725pub fn routes(server: Arc<Server>) -> Router<Body> {
726 Router::new()
727 .route("/rpc", get(handle_websocket_request))
728 .layer(
729 ServiceBuilder::new()
730 .layer(Extension(server.app_state.clone()))
731 .layer(middleware::from_fn(auth::validate_header)),
732 )
733 .route("/metrics", get(handle_metrics))
734 .layer(Extension(server))
735}
736
/// Axum handler for `/rpc`.
///
/// Rejects the request with `426 Upgrade Required` when the client's protocol
/// version differs from `rpc::PROTOCOL_VERSION`. Otherwise it upgrades to a
/// WebSocket, adapts the axum socket to the tungstenite message types expected by
/// `Connection`, and runs `Server::handle_connection` until the connection ends,
/// logging any error it returns.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Brings `log_err` into scope for the call below.
        use util::ResultExt;
        // Incoming axum messages are converted to tungstenite messages, and
        // outgoing tungstenite messages are converted back to axum messages.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
767
768pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
769 let connections = server
770 .connection_pool
771 .lock()
772 .connections()
773 .filter(|connection| !connection.admin)
774 .count();
775
776 METRIC_CONNECTIONS.set(connections as _);
777
778 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
779 METRIC_SHARED_PROJECTS.set(shared_projects as _);
780
781 let encoder = prometheus::TextEncoder::new();
782 let metric_families = prometheus::gather();
783 let encoded_metrics = encoder
784 .encode_to_string(&metric_families)
785 .map_err(|err| anyhow!("{}", err))?;
786 Ok(encoded_metrics)
787}
788
/// Cleans up after a connection has ended.
///
/// Immediately disconnects the peer connection, removes it from the pool, and
/// records the connection as lost in the database, notifying collaborators of
/// every project it left. It then waits for whichever comes first:
///
/// * `RECONNECT_TIMEOUT` elapsing — the session leaves its room; if the user has
///   no other connection online, any pending call to them is declined; finally
///   the user's contacts are updated.
/// * teardown being signalled — nothing further is done.
///
/// Errors from removing the connection and from the final contact update are
/// returned; the other failures are logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` checks the timer arm before the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
834
835async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
836 response.send(proto::Ack {})?;
837 Ok(())
838}
839
840async fn create_room(
841 _request: proto::CreateRoom,
842 response: Response<proto::CreateRoom>,
843 session: Session,
844) -> Result<()> {
845 let live_kit_room = nanoid::nanoid!(30);
846 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
847 if let Some(_) = live_kit
848 .create_room(live_kit_room.clone())
849 .await
850 .trace_err()
851 {
852 if let Some(token) = live_kit
853 .room_token(&live_kit_room, &session.user_id.to_string())
854 .trace_err()
855 {
856 Some(proto::LiveKitConnectionInfo {
857 server_url: live_kit.url().into(),
858 token,
859 })
860 } else {
861 None
862 }
863 } else {
864 None
865 }
866 } else {
867 None
868 };
869
870 {
871 let room = session
872 .db()
873 .await
874 .create_room(session.user_id, session.connection_id, &live_kit_room)
875 .await?;
876
877 response.send(proto::CreateRoomResponse {
878 room: Some(room.clone()),
879 live_kit_connection_info,
880 })?;
881 }
882
883 update_user_contacts(session.user_id, &session).await?;
884 Ok(())
885}
886
887async fn join_room(
888 request: proto::JoinRoom,
889 response: Response<proto::JoinRoom>,
890 session: Session,
891) -> Result<()> {
892 let room_id = RoomId::from_proto(request.id);
893 let room = {
894 let room = session
895 .db()
896 .await
897 .join_room(room_id, session.user_id, session.connection_id)
898 .await?;
899 room_updated(&room, &session.peer);
900 room.clone()
901 };
902
903 for connection_id in session
904 .connection_pool()
905 .await
906 .user_connection_ids(session.user_id)
907 {
908 session
909 .peer
910 .send(
911 connection_id,
912 proto::CallCanceled {
913 room_id: room_id.to_proto(),
914 },
915 )
916 .trace_err();
917 }
918
919 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
920 if let Some(token) = live_kit
921 .room_token(&room.live_kit_room, &session.user_id.to_string())
922 .trace_err()
923 {
924 Some(proto::LiveKitConnectionInfo {
925 server_url: live_kit.url().into(),
926 token,
927 })
928 } else {
929 None
930 }
931 } else {
932 None
933 };
934
935 response.send(proto::JoinRoomResponse {
936 room: Some(room),
937 live_kit_connection_info,
938 })?;
939
940 update_user_contacts(session.user_id, &session).await?;
941 Ok(())
942}
943
/// Handles `LeaveRoom` by delegating to `leave_room_for_session`; the message
/// payload itself is unused.
async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
    leave_room_for_session(&session).await
}
947
/// Rings `called_user_id` on behalf of the requesting session.
///
/// Fails up front when the two users are not contacts. Otherwise the call is
/// recorded in the database, room members and the callee's contacts are
/// updated, and the incoming-call message is sent as a request to every
/// connection of the callee. The first connection that answers successfully
/// causes an `Ack` to be sent to the caller. If no connection answers, the call
/// is marked as failed, the room and contacts are updated again, and a
/// "failed to ring user" error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The block limits how long the value returned by the database call is kept
    // alive; only the incoming-call message is moved out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per connection of the callee, raced against each other.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1015
1016async fn cancel_call(
1017 request: proto::CancelCall,
1018 response: Response<proto::CancelCall>,
1019 session: Session,
1020) -> Result<()> {
1021 let called_user_id = UserId::from_proto(request.called_user_id);
1022 let room_id = RoomId::from_proto(request.room_id);
1023 {
1024 let room = session
1025 .db()
1026 .await
1027 .cancel_call(room_id, session.connection_id, called_user_id)
1028 .await?;
1029 room_updated(&room, &session.peer);
1030 }
1031
1032 for connection_id in session
1033 .connection_pool()
1034 .await
1035 .user_connection_ids(called_user_id)
1036 {
1037 session
1038 .peer
1039 .send(
1040 connection_id,
1041 proto::CallCanceled {
1042 room_id: room_id.to_proto(),
1043 },
1044 )
1045 .trace_err();
1046 }
1047 response.send(proto::Ack {})?;
1048
1049 update_user_contacts(called_user_id, &session).await?;
1050 Ok(())
1051}
1052
1053async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1054 let room_id = RoomId::from_proto(message.room_id);
1055 {
1056 let room = session
1057 .db()
1058 .await
1059 .decline_call(Some(room_id), session.user_id)
1060 .await?
1061 .ok_or_else(|| anyhow!("failed to decline call"))?;
1062 room_updated(&room, &session.peer);
1063 }
1064
1065 for connection_id in session
1066 .connection_pool()
1067 .await
1068 .user_connection_ids(session.user_id)
1069 {
1070 session
1071 .peer
1072 .send(
1073 connection_id,
1074 proto::CallCanceled {
1075 room_id: room_id.to_proto(),
1076 },
1077 )
1078 .trace_err();
1079 }
1080 update_user_contacts(session.user_id, &session).await?;
1081 Ok(())
1082}
1083
1084async fn update_participant_location(
1085 request: proto::UpdateParticipantLocation,
1086 response: Response<proto::UpdateParticipantLocation>,
1087 session: Session,
1088) -> Result<()> {
1089 let room_id = RoomId::from_proto(request.room_id);
1090 let location = request
1091 .location
1092 .ok_or_else(|| anyhow!("invalid location"))?;
1093 let room = session
1094 .db()
1095 .await
1096 .update_room_participant_location(room_id, session.connection_id, location)
1097 .await?;
1098 room_updated(&room, &session.peer);
1099 response.send(proto::Ack {})?;
1100 Ok(())
1101}
1102
/// Shares a project from the requesting connection into a room.
///
/// Registers the project and its worktrees in the database, replies with the
/// new project id, and then notifies room members of the updated room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The value returned by the database call is borrowed in place and stays
    // alive for the remainder of the function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session.peer);

    Ok(())
}
1124
1125async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1126 let project_id = ProjectId::from_proto(message.project_id);
1127
1128 let (room, guest_connection_ids) = &*session
1129 .db()
1130 .await
1131 .unshare_project(project_id, session.connection_id)
1132 .await?;
1133
1134 broadcast(
1135 session.connection_id,
1136 guest_connection_ids.iter().copied(),
1137 |conn_id| session.peer.send(conn_id, message.clone()),
1138 );
1139 room_updated(&room, &session.peer);
1140
1141 Ok(())
1142}
1143
1144async fn join_project(
1145 request: proto::JoinProject,
1146 response: Response<proto::JoinProject>,
1147 session: Session,
1148) -> Result<()> {
1149 let project_id = ProjectId::from_proto(request.project_id);
1150 let guest_user_id = session.user_id;
1151
1152 tracing::info!(%project_id, "join project");
1153
1154 let (project, replica_id) = &mut *session
1155 .db()
1156 .await
1157 .join_project(project_id, session.connection_id)
1158 .await?;
1159
1160 let collaborators = project
1161 .collaborators
1162 .iter()
1163 .map(|collaborator| {
1164 let peer_id = proto::PeerId {
1165 owner_id: collaborator.connection_server_id.0 as u32,
1166 id: collaborator.connection_id as u32,
1167 };
1168 proto::Collaborator {
1169 peer_id: Some(peer_id),
1170 replica_id: collaborator.replica_id.0 as u32,
1171 user_id: collaborator.user_id.to_proto(),
1172 }
1173 })
1174 .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1175 .collect::<Vec<_>>();
1176 let worktrees = project
1177 .worktrees
1178 .iter()
1179 .map(|(id, worktree)| proto::WorktreeMetadata {
1180 id: *id,
1181 root_name: worktree.root_name.clone(),
1182 visible: worktree.visible,
1183 abs_path: worktree.abs_path.clone(),
1184 })
1185 .collect::<Vec<_>>();
1186
1187 for collaborator in &collaborators {
1188 session
1189 .peer
1190 .send(
1191 collaborator.peer_id.unwrap().into(),
1192 proto::AddProjectCollaborator {
1193 project_id: project_id.to_proto(),
1194 collaborator: Some(proto::Collaborator {
1195 peer_id: Some(session.connection_id.into()),
1196 replica_id: replica_id.0 as u32,
1197 user_id: guest_user_id.to_proto(),
1198 }),
1199 },
1200 )
1201 .trace_err();
1202 }
1203
1204 // First, we send the metadata associated with each worktree.
1205 response.send(proto::JoinProjectResponse {
1206 worktrees: worktrees.clone(),
1207 replica_id: replica_id.0 as u32,
1208 collaborators: collaborators.clone(),
1209 language_servers: project.language_servers.clone(),
1210 })?;
1211
1212 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1213 #[cfg(any(test, feature = "test-support"))]
1214 const MAX_CHUNK_SIZE: usize = 2;
1215 #[cfg(not(any(test, feature = "test-support")))]
1216 const MAX_CHUNK_SIZE: usize = 256;
1217
1218 // Stream this worktree's entries.
1219 let message = proto::UpdateWorktree {
1220 project_id: project_id.to_proto(),
1221 worktree_id,
1222 abs_path: worktree.abs_path.clone(),
1223 root_name: worktree.root_name,
1224 updated_entries: worktree.entries,
1225 removed_entries: Default::default(),
1226 scan_id: worktree.scan_id,
1227 is_last_update: worktree.is_complete,
1228 };
1229 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1230 session.peer.send(session.connection_id, update.clone())?;
1231 }
1232
1233 // Stream this worktree's diagnostics.
1234 for summary in worktree.diagnostic_summaries {
1235 session.peer.send(
1236 session.connection_id,
1237 proto::UpdateDiagnosticSummary {
1238 project_id: project_id.to_proto(),
1239 worktree_id: worktree.id,
1240 summary: Some(summary),
1241 },
1242 )?;
1243 }
1244 }
1245
1246 for language_server in &project.language_servers {
1247 session.peer.send(
1248 session.connection_id,
1249 proto::UpdateLanguageServer {
1250 project_id: project_id.to_proto(),
1251 language_server_id: language_server.id,
1252 variant: Some(
1253 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1254 proto::LspDiskBasedDiagnosticsUpdated {},
1255 ),
1256 ),
1257 },
1258 )?;
1259 }
1260
1261 Ok(())
1262}
1263
1264async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1265 let sender_id = session.connection_id;
1266 let project_id = ProjectId::from_proto(request.project_id);
1267
1268 let project = session
1269 .db()
1270 .await
1271 .leave_project(project_id, sender_id)
1272 .await?;
1273 tracing::info!(
1274 %project_id,
1275 host_user_id = %project.host_user_id,
1276 host_connection_id = %project.host_connection_id,
1277 "leave project"
1278 );
1279 project_left(&project, &session);
1280
1281 Ok(())
1282}
1283
1284async fn update_project(
1285 request: proto::UpdateProject,
1286 response: Response<proto::UpdateProject>,
1287 session: Session,
1288) -> Result<()> {
1289 let project_id = ProjectId::from_proto(request.project_id);
1290 let (room, guest_connection_ids) = &*session
1291 .db()
1292 .await
1293 .update_project(project_id, session.connection_id, &request.worktrees)
1294 .await?;
1295 broadcast(
1296 session.connection_id,
1297 guest_connection_ids.iter().copied(),
1298 |connection_id| {
1299 session
1300 .peer
1301 .forward_send(session.connection_id, connection_id, request.clone())
1302 },
1303 );
1304 room_updated(&room, &session.peer);
1305 response.send(proto::Ack {})?;
1306
1307 Ok(())
1308}
1309
1310async fn update_worktree(
1311 request: proto::UpdateWorktree,
1312 response: Response<proto::UpdateWorktree>,
1313 session: Session,
1314) -> Result<()> {
1315 let guest_connection_ids = session
1316 .db()
1317 .await
1318 .update_worktree(&request, session.connection_id)
1319 .await?;
1320
1321 broadcast(
1322 session.connection_id,
1323 guest_connection_ids.iter().copied(),
1324 |connection_id| {
1325 session
1326 .peer
1327 .forward_send(session.connection_id, connection_id, request.clone())
1328 },
1329 );
1330 response.send(proto::Ack {})?;
1331 Ok(())
1332}
1333
1334async fn update_diagnostic_summary(
1335 message: proto::UpdateDiagnosticSummary,
1336 session: Session,
1337) -> Result<()> {
1338 let guest_connection_ids = session
1339 .db()
1340 .await
1341 .update_diagnostic_summary(&message, session.connection_id)
1342 .await?;
1343
1344 broadcast(
1345 session.connection_id,
1346 guest_connection_ids.iter().copied(),
1347 |connection_id| {
1348 session
1349 .peer
1350 .forward_send(session.connection_id, connection_id, message.clone())
1351 },
1352 );
1353
1354 Ok(())
1355}
1356
1357async fn start_language_server(
1358 request: proto::StartLanguageServer,
1359 session: Session,
1360) -> Result<()> {
1361 let guest_connection_ids = session
1362 .db()
1363 .await
1364 .start_language_server(&request, session.connection_id)
1365 .await?;
1366
1367 broadcast(
1368 session.connection_id,
1369 guest_connection_ids.iter().copied(),
1370 |connection_id| {
1371 session
1372 .peer
1373 .forward_send(session.connection_id, connection_id, request.clone())
1374 },
1375 );
1376 Ok(())
1377}
1378
1379async fn update_language_server(
1380 request: proto::UpdateLanguageServer,
1381 session: Session,
1382) -> Result<()> {
1383 let project_id = ProjectId::from_proto(request.project_id);
1384 let project_connection_ids = session
1385 .db()
1386 .await
1387 .project_connection_ids(project_id, session.connection_id)
1388 .await?;
1389 broadcast(
1390 session.connection_id,
1391 project_connection_ids.iter().copied(),
1392 |connection_id| {
1393 session
1394 .peer
1395 .forward_send(session.connection_id, connection_id, request.clone())
1396 },
1397 );
1398 Ok(())
1399}
1400
1401async fn forward_project_request<T>(
1402 request: T,
1403 response: Response<T>,
1404 session: Session,
1405) -> Result<()>
1406where
1407 T: EntityMessage + RequestMessage,
1408{
1409 let project_id = ProjectId::from_proto(request.remote_entity_id());
1410 let host_connection_id = {
1411 let collaborators = session
1412 .db()
1413 .await
1414 .project_collaborators(project_id, session.connection_id)
1415 .await?;
1416 let host = collaborators
1417 .iter()
1418 .find(|collaborator| collaborator.is_host)
1419 .ok_or_else(|| anyhow!("host not found"))?;
1420 ConnectionId {
1421 owner_id: host.connection_server_id.0 as u32,
1422 id: host.connection_id as u32,
1423 }
1424 };
1425
1426 let payload = session
1427 .peer
1428 .forward_request(session.connection_id, host_connection_id, request)
1429 .await?;
1430
1431 response.send(payload)?;
1432 Ok(())
1433}
1434
1435async fn save_buffer(
1436 request: proto::SaveBuffer,
1437 response: Response<proto::SaveBuffer>,
1438 session: Session,
1439) -> Result<()> {
1440 let project_id = ProjectId::from_proto(request.project_id);
1441 let host_connection_id = {
1442 let collaborators = session
1443 .db()
1444 .await
1445 .project_collaborators(project_id, session.connection_id)
1446 .await?;
1447 let host = collaborators
1448 .iter()
1449 .find(|collaborator| collaborator.is_host)
1450 .ok_or_else(|| anyhow!("host not found"))?;
1451 ConnectionId {
1452 owner_id: host.connection_server_id.0 as u32,
1453 id: host.connection_id as u32,
1454 }
1455 };
1456 let response_payload = session
1457 .peer
1458 .forward_request(session.connection_id, host_connection_id, request.clone())
1459 .await?;
1460
1461 let mut collaborators = session
1462 .db()
1463 .await
1464 .project_collaborators(project_id, session.connection_id)
1465 .await?;
1466 collaborators.retain(|collaborator| {
1467 let collaborator_connection = ConnectionId {
1468 owner_id: collaborator.connection_server_id.0 as u32,
1469 id: collaborator.connection_id as u32,
1470 };
1471 collaborator_connection != session.connection_id
1472 });
1473 let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
1474 owner_id: collaborator.connection_server_id.0 as u32,
1475 id: collaborator.connection_id as u32,
1476 });
1477 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1478 session
1479 .peer
1480 .forward_send(host_connection_id, conn_id, response_payload.clone())
1481 });
1482 response.send(response_payload)?;
1483 Ok(())
1484}
1485
1486async fn create_buffer_for_peer(
1487 request: proto::CreateBufferForPeer,
1488 session: Session,
1489) -> Result<()> {
1490 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1491 session
1492 .peer
1493 .forward_send(session.connection_id, peer_id.into(), request)?;
1494 Ok(())
1495}
1496
1497async fn update_buffer(
1498 request: proto::UpdateBuffer,
1499 response: Response<proto::UpdateBuffer>,
1500 session: Session,
1501) -> Result<()> {
1502 let project_id = ProjectId::from_proto(request.project_id);
1503 let project_connection_ids = session
1504 .db()
1505 .await
1506 .project_connection_ids(project_id, session.connection_id)
1507 .await?;
1508
1509 broadcast(
1510 session.connection_id,
1511 project_connection_ids.iter().copied(),
1512 |connection_id| {
1513 session
1514 .peer
1515 .forward_send(session.connection_id, connection_id, request.clone())
1516 },
1517 );
1518 response.send(proto::Ack {})?;
1519 Ok(())
1520}
1521
1522async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1523 let project_id = ProjectId::from_proto(request.project_id);
1524 let project_connection_ids = session
1525 .db()
1526 .await
1527 .project_connection_ids(project_id, session.connection_id)
1528 .await?;
1529
1530 broadcast(
1531 session.connection_id,
1532 project_connection_ids.iter().copied(),
1533 |connection_id| {
1534 session
1535 .peer
1536 .forward_send(session.connection_id, connection_id, request.clone())
1537 },
1538 );
1539 Ok(())
1540}
1541
1542async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1543 let project_id = ProjectId::from_proto(request.project_id);
1544 let project_connection_ids = session
1545 .db()
1546 .await
1547 .project_connection_ids(project_id, session.connection_id)
1548 .await?;
1549 broadcast(
1550 session.connection_id,
1551 project_connection_ids.iter().copied(),
1552 |connection_id| {
1553 session
1554 .peer
1555 .forward_send(session.connection_id, connection_id, request.clone())
1556 },
1557 );
1558 Ok(())
1559}
1560
1561async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1562 let project_id = ProjectId::from_proto(request.project_id);
1563 let project_connection_ids = session
1564 .db()
1565 .await
1566 .project_connection_ids(project_id, session.connection_id)
1567 .await?;
1568 broadcast(
1569 session.connection_id,
1570 project_connection_ids.iter().copied(),
1571 |connection_id| {
1572 session
1573 .peer
1574 .forward_send(session.connection_id, connection_id, request.clone())
1575 },
1576 );
1577 Ok(())
1578}
1579
1580async fn follow(
1581 request: proto::Follow,
1582 response: Response<proto::Follow>,
1583 session: Session,
1584) -> Result<()> {
1585 let project_id = ProjectId::from_proto(request.project_id);
1586 let leader_id = request
1587 .leader_id
1588 .ok_or_else(|| anyhow!("invalid leader id"))?
1589 .into();
1590 let follower_id = session.connection_id;
1591 {
1592 let project_connection_ids = session
1593 .db()
1594 .await
1595 .project_connection_ids(project_id, session.connection_id)
1596 .await?;
1597
1598 if !project_connection_ids.contains(&leader_id) {
1599 Err(anyhow!("no such peer"))?;
1600 }
1601 }
1602
1603 let mut response_payload = session
1604 .peer
1605 .forward_request(session.connection_id, leader_id, request)
1606 .await?;
1607 response_payload
1608 .views
1609 .retain(|view| view.leader_id != Some(follower_id.into()));
1610 response.send(response_payload)?;
1611 Ok(())
1612}
1613
1614async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1615 let project_id = ProjectId::from_proto(request.project_id);
1616 let leader_id = request
1617 .leader_id
1618 .ok_or_else(|| anyhow!("invalid leader id"))?
1619 .into();
1620 let project_connection_ids = session
1621 .db()
1622 .await
1623 .project_connection_ids(project_id, session.connection_id)
1624 .await?;
1625 if !project_connection_ids.contains(&leader_id) {
1626 Err(anyhow!("no such peer"))?;
1627 }
1628 session
1629 .peer
1630 .forward_send(session.connection_id, leader_id, request)?;
1631 Ok(())
1632}
1633
1634async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1635 let project_id = ProjectId::from_proto(request.project_id);
1636 let project_connection_ids = session
1637 .db
1638 .lock()
1639 .await
1640 .project_connection_ids(project_id, session.connection_id)
1641 .await?;
1642
1643 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1644 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1645 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1646 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1647 });
1648 for follower_peer_id in request.follower_ids.iter().copied() {
1649 let follower_connection_id = follower_peer_id.into();
1650 if project_connection_ids.contains(&follower_connection_id)
1651 && Some(follower_peer_id) != leader_id
1652 {
1653 session.peer.forward_send(
1654 session.connection_id,
1655 follower_connection_id,
1656 request.clone(),
1657 )?;
1658 }
1659 }
1660 Ok(())
1661}
1662
1663async fn get_users(
1664 request: proto::GetUsers,
1665 response: Response<proto::GetUsers>,
1666 session: Session,
1667) -> Result<()> {
1668 let user_ids = request
1669 .user_ids
1670 .into_iter()
1671 .map(UserId::from_proto)
1672 .collect();
1673 let users = session
1674 .db()
1675 .await
1676 .get_users_by_ids(user_ids)
1677 .await?
1678 .into_iter()
1679 .map(|user| proto::User {
1680 id: user.id.to_proto(),
1681 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1682 github_login: user.github_login,
1683 })
1684 .collect();
1685 response.send(proto::UsersResponse { users })?;
1686 Ok(())
1687}
1688
1689async fn fuzzy_search_users(
1690 request: proto::FuzzySearchUsers,
1691 response: Response<proto::FuzzySearchUsers>,
1692 session: Session,
1693) -> Result<()> {
1694 let query = request.query;
1695 let users = match query.len() {
1696 0 => vec![],
1697 1 | 2 => session
1698 .db()
1699 .await
1700 .get_user_by_github_account(&query, None)
1701 .await?
1702 .into_iter()
1703 .collect(),
1704 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1705 };
1706 let users = users
1707 .into_iter()
1708 .filter(|user| user.id != session.user_id)
1709 .map(|user| proto::User {
1710 id: user.id.to_proto(),
1711 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1712 github_login: user.github_login,
1713 })
1714 .collect();
1715 response.send(proto::UsersResponse { users })?;
1716 Ok(())
1717}
1718
1719async fn request_contact(
1720 request: proto::RequestContact,
1721 response: Response<proto::RequestContact>,
1722 session: Session,
1723) -> Result<()> {
1724 let requester_id = session.user_id;
1725 let responder_id = UserId::from_proto(request.responder_id);
1726 if requester_id == responder_id {
1727 return Err(anyhow!("cannot add yourself as a contact"))?;
1728 }
1729
1730 session
1731 .db()
1732 .await
1733 .send_contact_request(requester_id, responder_id)
1734 .await?;
1735
1736 // Update outgoing contact requests of requester
1737 let mut update = proto::UpdateContacts::default();
1738 update.outgoing_requests.push(responder_id.to_proto());
1739 for connection_id in session
1740 .connection_pool()
1741 .await
1742 .user_connection_ids(requester_id)
1743 {
1744 session.peer.send(connection_id, update.clone())?;
1745 }
1746
1747 // Update incoming contact requests of responder
1748 let mut update = proto::UpdateContacts::default();
1749 update
1750 .incoming_requests
1751 .push(proto::IncomingContactRequest {
1752 requester_id: requester_id.to_proto(),
1753 should_notify: true,
1754 });
1755 for connection_id in session
1756 .connection_pool()
1757 .await
1758 .user_connection_ids(responder_id)
1759 {
1760 session.peer.send(connection_id, update.clone())?;
1761 }
1762
1763 response.send(proto::Ack {})?;
1764 Ok(())
1765}
1766
1767async fn respond_to_contact_request(
1768 request: proto::RespondToContactRequest,
1769 response: Response<proto::RespondToContactRequest>,
1770 session: Session,
1771) -> Result<()> {
1772 let responder_id = session.user_id;
1773 let requester_id = UserId::from_proto(request.requester_id);
1774 let db = session.db().await;
1775 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1776 db.dismiss_contact_notification(responder_id, requester_id)
1777 .await?;
1778 } else {
1779 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1780
1781 db.respond_to_contact_request(responder_id, requester_id, accept)
1782 .await?;
1783 let requester_busy = db.is_user_busy(requester_id).await?;
1784 let responder_busy = db.is_user_busy(responder_id).await?;
1785
1786 let pool = session.connection_pool().await;
1787 // Update responder with new contact
1788 let mut update = proto::UpdateContacts::default();
1789 if accept {
1790 update
1791 .contacts
1792 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1793 }
1794 update
1795 .remove_incoming_requests
1796 .push(requester_id.to_proto());
1797 for connection_id in pool.user_connection_ids(responder_id) {
1798 session.peer.send(connection_id, update.clone())?;
1799 }
1800
1801 // Update requester with new contact
1802 let mut update = proto::UpdateContacts::default();
1803 if accept {
1804 update
1805 .contacts
1806 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1807 }
1808 update
1809 .remove_outgoing_requests
1810 .push(responder_id.to_proto());
1811 for connection_id in pool.user_connection_ids(requester_id) {
1812 session.peer.send(connection_id, update.clone())?;
1813 }
1814 }
1815
1816 response.send(proto::Ack {})?;
1817 Ok(())
1818}
1819
1820async fn remove_contact(
1821 request: proto::RemoveContact,
1822 response: Response<proto::RemoveContact>,
1823 session: Session,
1824) -> Result<()> {
1825 let requester_id = session.user_id;
1826 let responder_id = UserId::from_proto(request.user_id);
1827 let db = session.db().await;
1828 db.remove_contact(requester_id, responder_id).await?;
1829
1830 let pool = session.connection_pool().await;
1831 // Update outgoing contact requests of requester
1832 let mut update = proto::UpdateContacts::default();
1833 update
1834 .remove_outgoing_requests
1835 .push(responder_id.to_proto());
1836 for connection_id in pool.user_connection_ids(requester_id) {
1837 session.peer.send(connection_id, update.clone())?;
1838 }
1839
1840 // Update incoming contact requests of responder
1841 let mut update = proto::UpdateContacts::default();
1842 update
1843 .remove_incoming_requests
1844 .push(requester_id.to_proto());
1845 for connection_id in pool.user_connection_ids(responder_id) {
1846 session.peer.send(connection_id, update.clone())?;
1847 }
1848
1849 response.send(proto::Ack {})?;
1850 Ok(())
1851}
1852
1853async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1854 let project_id = ProjectId::from_proto(request.project_id);
1855 let project_connection_ids = session
1856 .db()
1857 .await
1858 .project_connection_ids(project_id, session.connection_id)
1859 .await?;
1860 broadcast(
1861 session.connection_id,
1862 project_connection_ids.iter().copied(),
1863 |connection_id| {
1864 session
1865 .peer
1866 .forward_send(session.connection_id, connection_id, request.clone())
1867 },
1868 );
1869 Ok(())
1870}
1871
1872async fn get_private_user_info(
1873 _request: proto::GetPrivateUserInfo,
1874 response: Response<proto::GetPrivateUserInfo>,
1875 session: Session,
1876) -> Result<()> {
1877 let metrics_id = session
1878 .db()
1879 .await
1880 .get_user_metrics_id(session.user_id)
1881 .await?;
1882 let user = session
1883 .db()
1884 .await
1885 .get_user_by_id(session.user_id)
1886 .await?
1887 .ok_or_else(|| anyhow!("user not found"))?;
1888 response.send(proto::GetPrivateUserInfoResponse {
1889 metrics_id,
1890 staff: user.admin,
1891 })?;
1892 Ok(())
1893}
1894
1895fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1896 match message {
1897 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1898 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1899 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1900 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1901 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1902 code: frame.code.into(),
1903 reason: frame.reason,
1904 })),
1905 }
1906}
1907
1908fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1909 match message {
1910 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1911 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1912 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1913 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1914 AxumMessage::Close(frame) => {
1915 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1916 code: frame.code.into(),
1917 reason: frame.reason,
1918 }))
1919 }
1920 }
1921}
1922
1923fn build_initial_contacts_update(
1924 contacts: Vec<db::Contact>,
1925 pool: &ConnectionPool,
1926) -> proto::UpdateContacts {
1927 let mut update = proto::UpdateContacts::default();
1928
1929 for contact in contacts {
1930 match contact {
1931 db::Contact::Accepted {
1932 user_id,
1933 should_notify,
1934 busy,
1935 } => {
1936 update
1937 .contacts
1938 .push(contact_for_user(user_id, should_notify, busy, &pool));
1939 }
1940 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1941 db::Contact::Incoming {
1942 user_id,
1943 should_notify,
1944 } => update
1945 .incoming_requests
1946 .push(proto::IncomingContactRequest {
1947 requester_id: user_id.to_proto(),
1948 should_notify,
1949 }),
1950 }
1951 }
1952
1953 update
1954}
1955
1956fn contact_for_user(
1957 user_id: UserId,
1958 should_notify: bool,
1959 busy: bool,
1960 pool: &ConnectionPool,
1961) -> proto::Contact {
1962 proto::Contact {
1963 user_id: user_id.to_proto(),
1964 online: pool.is_user_online(user_id),
1965 busy,
1966 should_notify,
1967 }
1968}
1969
1970fn room_updated(room: &proto::Room, peer: &Peer) {
1971 for participant in &room.participants {
1972 if let Some(peer_id) = participant
1973 .peer_id
1974 .ok_or_else(|| anyhow!("invalid participant peer id"))
1975 .trace_err()
1976 {
1977 peer.send(
1978 peer_id.into(),
1979 proto::RoomUpdated {
1980 room: Some(room.clone()),
1981 },
1982 )
1983 .trace_err();
1984 }
1985 }
1986}
1987
1988async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1989 let db = session.db().await;
1990 let contacts = db.get_contacts(user_id).await?;
1991 let busy = db.is_user_busy(user_id).await?;
1992
1993 let pool = session.connection_pool().await;
1994 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1995 for contact in contacts {
1996 if let db::Contact::Accepted {
1997 user_id: contact_user_id,
1998 ..
1999 } = contact
2000 {
2001 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2002 session
2003 .peer
2004 .send(
2005 contact_conn_id,
2006 proto::UpdateContacts {
2007 contacts: vec![updated_contact.clone()],
2008 remove_contacts: Default::default(),
2009 incoming_requests: Default::default(),
2010 remove_incoming_requests: Default::default(),
2011 outgoing_requests: Default::default(),
2012 remove_outgoing_requests: Default::default(),
2013 },
2014 )
2015 .trace_err();
2016 }
2017 }
2018 }
2019 Ok(())
2020}
2021
2022async fn leave_room_for_session(session: &Session) -> Result<()> {
2023 let mut contacts_to_update = HashSet::default();
2024
2025 let room_id;
2026 let canceled_calls_to_user_ids;
2027 let live_kit_room;
2028 let delete_live_kit_room;
2029 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2030 contacts_to_update.insert(session.user_id);
2031
2032 for project in left_room.left_projects.values() {
2033 project_left(project, session);
2034 }
2035
2036 room_updated(&left_room.room, &session.peer);
2037 room_id = RoomId::from_proto(left_room.room.id);
2038 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2039 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2040 delete_live_kit_room = left_room.room.participants.is_empty();
2041 } else {
2042 return Ok(());
2043 }
2044
2045 {
2046 let pool = session.connection_pool().await;
2047 for canceled_user_id in canceled_calls_to_user_ids {
2048 for connection_id in pool.user_connection_ids(canceled_user_id) {
2049 session
2050 .peer
2051 .send(
2052 connection_id,
2053 proto::CallCanceled {
2054 room_id: room_id.to_proto(),
2055 },
2056 )
2057 .trace_err();
2058 }
2059 contacts_to_update.insert(canceled_user_id);
2060 }
2061 }
2062
2063 for contact_user_id in contacts_to_update {
2064 update_user_contacts(contact_user_id, &session).await?;
2065 }
2066
2067 if let Some(live_kit) = session.live_kit_client.as_ref() {
2068 live_kit
2069 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2070 .await
2071 .trace_err();
2072
2073 if delete_live_kit_room {
2074 live_kit.delete_room(live_kit_room).await.trace_err();
2075 }
2076 }
2077
2078 Ok(())
2079}
2080
2081fn project_left(project: &db::LeftProject, session: &Session) {
2082 for connection_id in &project.connection_ids {
2083 if project.host_user_id == session.user_id {
2084 session
2085 .peer
2086 .send(
2087 *connection_id,
2088 proto::UnshareProject {
2089 project_id: project.id.to_proto(),
2090 },
2091 )
2092 .trace_err();
2093 } else {
2094 session
2095 .peer
2096 .send(
2097 *connection_id,
2098 proto::RemoveProjectCollaborator {
2099 project_id: project.id.to_proto(),
2100 peer_id: Some(session.connection_id.into()),
2101 },
2102 )
2103 .trace_err();
2104 }
2105 }
2106
2107 session
2108 .peer
2109 .send(
2110 session.connection_id,
2111 proto::UnshareProject {
2112 project_id: project.id.to_proto(),
2113 },
2114 )
2115 .trace_err();
2116}
2117
/// Extension trait for result-like values whose errors should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success value produced when there was no error.
    type Ok;

    /// Consumes `self`, returning the success value if there is one.
    ///
    /// Implementations are expected to report the error case (for example by logging it)
    /// and return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
2123
2124impl<T, E> ResultExt for Result<T, E>
2125where
2126 E: std::fmt::Debug,
2127{
2128 type Ok = T;
2129
2130 fn trace_err(self) -> Option<T> {
2131 match self {
2132 Ok(value) => Some(value),
2133 Err(error) => {
2134 tracing::error!("{:?}", error);
2135 None
2136 }
2137 }
2138 }
2139}