1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits (unless the server is torn down first) before it
/// removes a disconnected user from their room and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// How long the detached cleanup task spawned by `Server::start` waits before
/// fetching and refreshing the rooms reported by `stale_room_ids`.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

lazy_static! {
    /// Gauge of non-admin connections; refreshed on every `handle_metrics` call.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge of shared projects as counted by `project_count_excluding_admins`;
    /// refreshed on every `handle_metrics` call.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
/// Type-erased handler for a single RPC message type.
///
/// It receives the boxed envelope plus the sender's [`Session`] and returns a
/// boxed future that processes the message. Handlers are stored in
/// `Server::handlers`, keyed by the message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The collaboration RPC server: owns the [`Peer`], the shared connection
/// pool, and the table of registered message handlers.
pub struct Server {
    /// Identifier of this server instance (replaced by `reset` in tests).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool shared with every `Session` created by `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; connection loops and pending sign-outs watch it.
    teardown: watch::Sender<()>,
}
149
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
154
/// Serializable view of the server's state, produced by `Server::snapshot`.
///
/// The connection pool is serialized through [`serialize_deref`]. Because the
/// snapshot holds a `ConnectionPoolGuard`, the pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 id: parking_lot::Mutex::new(id),
175 peer: Peer::new(id.0 as u32),
176 app_state,
177 executor,
178 connection_pool: Default::default(),
179 handlers: Default::default(),
180 teardown: watch::channel(()).0,
181 };
182
183 server
184 .add_request_handler(ping)
185 .add_request_handler(create_room)
186 .add_request_handler(join_room)
187 .add_message_handler(leave_room)
188 .add_request_handler(call)
189 .add_request_handler(cancel_call)
190 .add_message_handler(decline_call)
191 .add_request_handler(update_participant_location)
192 .add_request_handler(share_project)
193 .add_message_handler(unshare_project)
194 .add_request_handler(join_project)
195 .add_message_handler(leave_project)
196 .add_request_handler(update_project)
197 .add_request_handler(update_worktree)
198 .add_message_handler(start_language_server)
199 .add_message_handler(update_language_server)
200 .add_message_handler(update_diagnostic_summary)
201 .add_request_handler(forward_project_request::<proto::GetHover>)
202 .add_request_handler(forward_project_request::<proto::GetDefinition>)
203 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
204 .add_request_handler(forward_project_request::<proto::GetReferences>)
205 .add_request_handler(forward_project_request::<proto::SearchProject>)
206 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
207 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
211 .add_request_handler(forward_project_request::<proto::GetCompletions>)
212 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
213 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
215 .add_request_handler(forward_project_request::<proto::PrepareRename>)
216 .add_request_handler(forward_project_request::<proto::PerformRename>)
217 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
218 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
219 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
222 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
223 .add_message_handler(create_buffer_for_peer)
224 .add_request_handler(update_buffer)
225 .add_message_handler(update_buffer_file)
226 .add_message_handler(buffer_reloaded)
227 .add_message_handler(buffer_saved)
228 .add_request_handler(save_buffer)
229 .add_request_handler(get_users)
230 .add_request_handler(fuzzy_search_users)
231 .add_request_handler(request_contact)
232 .add_request_handler(remove_contact)
233 .add_request_handler(respond_to_contact_request)
234 .add_request_handler(follow)
235 .add_message_handler(unfollow)
236 .add_message_handler(update_followers)
237 .add_message_handler(update_diff_base)
238 .add_request_handler(get_private_user_info);
239
240 Arc::new(server)
241 }
242
243 pub async fn start(&self) -> Result<()> {
244 let server_id = *self.id.lock();
245 let app_state = self.app_state.clone();
246 let peer = self.peer.clone();
247 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
248 let pool = self.connection_pool.clone();
249 let live_kit_client = self.app_state.live_kit_client.clone();
250
251 let span = info_span!("start server");
252 let span_enter = span.enter();
253
254 tracing::info!("begin deleting stale projects");
255 app_state
256 .db
257 .delete_stale_projects(&app_state.config.zed_environment, server_id)
258 .await?;
259 tracing::info!("finish deleting stale projects");
260
261 drop(span_enter);
262 self.executor.spawn_detached(
263 async move {
264 tracing::info!("waiting for cleanup timeout");
265 timeout.await;
266 tracing::info!("cleanup timeout expired, retrieving stale rooms");
267 if let Some(room_ids) = app_state
268 .db
269 .stale_room_ids(&app_state.config.zed_environment, server_id)
270 .await
271 .trace_err()
272 {
273 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
274 for room_id in room_ids {
275 let mut contacts_to_update = HashSet::default();
276 let mut canceled_calls_to_user_ids = Vec::new();
277 let mut live_kit_room = String::new();
278 let mut delete_live_kit_room = false;
279
280 if let Ok(mut refreshed_room) =
281 app_state.db.refresh_room(room_id, server_id).await
282 {
283 tracing::info!(
284 room_id = room_id.0,
285 new_participant_count = refreshed_room.room.participants.len(),
286 "refreshed room"
287 );
288 room_updated(&refreshed_room.room, &peer);
289 contacts_to_update
290 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
291 contacts_to_update
292 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
293 canceled_calls_to_user_ids =
294 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
295 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
296 delete_live_kit_room = refreshed_room.room.participants.is_empty();
297 }
298
299 {
300 let pool = pool.lock();
301 for canceled_user_id in canceled_calls_to_user_ids {
302 for connection_id in pool.user_connection_ids(canceled_user_id) {
303 peer.send(
304 connection_id,
305 proto::CallCanceled {
306 room_id: room_id.to_proto(),
307 },
308 )
309 .trace_err();
310 }
311 }
312 }
313
314 for user_id in contacts_to_update {
315 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
316 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
317 if let Some((busy, contacts)) = busy.zip(contacts) {
318 let pool = pool.lock();
319 let updated_contact = contact_for_user(user_id, false, busy, &pool);
320 for contact in contacts {
321 if let db::Contact::Accepted {
322 user_id: contact_user_id,
323 ..
324 } = contact
325 {
326 for contact_conn_id in
327 pool.user_connection_ids(contact_user_id)
328 {
329 peer.send(
330 contact_conn_id,
331 proto::UpdateContacts {
332 contacts: vec![updated_contact.clone()],
333 remove_contacts: Default::default(),
334 incoming_requests: Default::default(),
335 remove_incoming_requests: Default::default(),
336 outgoing_requests: Default::default(),
337 remove_outgoing_requests: Default::default(),
338 },
339 )
340 .trace_err();
341 }
342 }
343 }
344 }
345 }
346
347 if let Some(live_kit) = live_kit_client.as_ref() {
348 if delete_live_kit_room {
349 live_kit.delete_room(live_kit_room).await.trace_err();
350 }
351 }
352 }
353 }
354
355 app_state
356 .db
357 .delete_stale_servers(server_id, &app_state.config.zed_environment)
358 .await
359 .trace_err();
360 }
361 .instrument(span),
362 );
363 Ok(())
364 }
365
366 pub fn teardown(&self) {
367 self.peer.teardown();
368 self.connection_pool.lock().reset();
369 let _ = self.teardown.send(());
370 }
371
372 #[cfg(test)]
373 pub fn reset(&self, id: ServerId) {
374 self.teardown();
375 *self.id.lock() = id;
376 self.peer.reset(id.0 as u32);
377 }
378
379 #[cfg(test)]
380 pub fn id(&self) -> ServerId {
381 *self.id.lock()
382 }
383
384 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
385 where
386 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
387 Fut: 'static + Send + Future<Output = Result<()>>,
388 M: EnvelopedMessage,
389 {
390 let prev_handler = self.handlers.insert(
391 TypeId::of::<M>(),
392 Box::new(move |envelope, session| {
393 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
394 let span = info_span!(
395 "handle message",
396 payload_type = envelope.payload_type_name()
397 );
398 span.in_scope(|| {
399 tracing::info!(
400 payload_type = envelope.payload_type_name(),
401 "message received"
402 );
403 });
404 let future = (handler)(*envelope, session);
405 async move {
406 if let Err(error) = future.await {
407 tracing::error!(%error, "error handling message");
408 }
409 }
410 .instrument(span)
411 .boxed()
412 }),
413 );
414 if prev_handler.is_some() {
415 panic!("registered a handler for the same message twice");
416 }
417 self
418 }
419
420 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
421 where
422 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
423 Fut: 'static + Send + Future<Output = Result<()>>,
424 M: EnvelopedMessage,
425 {
426 self.add_handler(move |envelope, session| handler(envelope.payload, session));
427 self
428 }
429
430 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
431 where
432 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
433 Fut: Send + Future<Output = Result<()>>,
434 M: RequestMessage,
435 {
436 let handler = Arc::new(handler);
437 self.add_handler(move |envelope, session| {
438 let receipt = envelope.receipt();
439 let handler = handler.clone();
440 async move {
441 let peer = session.peer.clone();
442 let responded = Arc::new(AtomicBool::default());
443 let response = Response {
444 peer: peer.clone(),
445 responded: responded.clone(),
446 receipt,
447 };
448 match (handler)(envelope.payload, response, session).await {
449 Ok(()) => {
450 if responded.load(std::sync::atomic::Ordering::SeqCst) {
451 Ok(())
452 } else {
453 Err(anyhow!("handler did not send a response"))?
454 }
455 }
456 Err(error) => {
457 peer.respond_with_error(
458 receipt,
459 proto::Error {
460 message: error.to_string(),
461 },
462 )?;
463 Err(error)
464 }
465 }
466 }
467 })
468 }
469
470 pub fn handle_connection(
471 self: &Arc<Self>,
472 connection: Connection,
473 address: String,
474 user: User,
475 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
476 executor: Executor,
477 ) -> impl Future<Output = Result<()>> {
478 let this = self.clone();
479 let user_id = user.id;
480 let login = user.github_login;
481 let span = info_span!("handle connection", %user_id, %login, %address);
482 let mut teardown = self.teardown.subscribe();
483 async move {
484 let (connection_id, handle_io, mut incoming_rx) = this
485 .peer
486 .add_connection(connection, {
487 let executor = executor.clone();
488 move |duration| executor.sleep(duration)
489 });
490
491 tracing::info!(%user_id, %login, ?connection_id, %address, "connection opened");
492 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
493 tracing::info!(%user_id, %login, ?connection_id, %address, "sent hello message");
494
495 if let Some(send_connection_id) = send_connection_id.take() {
496 let _ = send_connection_id.send(connection_id);
497 }
498
499 if !user.connected_once {
500 this.peer.send(connection_id, proto::ShowContacts {})?;
501 this.app_state.db.set_user_connected_once(user_id, true).await?;
502 }
503
504 let (contacts, invite_code) = future::try_join(
505 this.app_state.db.get_contacts(user_id),
506 this.app_state.db.get_invite_code_for_user(user_id)
507 ).await?;
508
509 {
510 let mut pool = this.connection_pool.lock();
511 pool.add_connection(connection_id, user_id, user.admin);
512 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
513
514 if let Some((code, count)) = invite_code {
515 this.peer.send(connection_id, proto::UpdateInviteInfo {
516 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
517 count: count as u32,
518 })?;
519 }
520 }
521
522 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
523 this.peer.send(connection_id, incoming_call)?;
524 }
525
526 let session = Session {
527 user_id,
528 connection_id,
529 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
530 peer: this.peer.clone(),
531 connection_pool: this.connection_pool.clone(),
532 live_kit_client: this.app_state.live_kit_client.clone()
533 };
534 update_user_contacts(user_id, &session).await?;
535
536 let handle_io = handle_io.fuse();
537 futures::pin_mut!(handle_io);
538
539 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
540 // This prevents deadlocks when e.g., client A performs a request to client B and
541 // client B performs a request to client A. If both clients stop processing further
542 // messages until their respective request completes, they won't have a chance to
543 // respond to the other client's request and cause a deadlock.
544 //
545 // This arrangement ensures we will attempt to process earlier messages first, but fall
546 // back to processing messages arrived later in the spirit of making progress.
547 let mut foreground_message_handlers = FuturesUnordered::new();
548 loop {
549 let next_message = incoming_rx.next().fuse();
550 futures::pin_mut!(next_message);
551 futures::select_biased! {
552 _ = teardown.changed().fuse() => return Ok(()),
553 result = handle_io => {
554 if let Err(error) = result {
555 tracing::error!(?error, %user_id, %login, ?connection_id, %address, "error handling I/O");
556 }
557 break;
558 }
559 _ = foreground_message_handlers.next() => {}
560 message = next_message => {
561 if let Some(message) = message {
562 let type_name = message.payload_type_name();
563 let span = tracing::info_span!("receive message", %user_id, %login, ?connection_id, %address, type_name);
564 let span_enter = span.enter();
565 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
566 let is_background = message.is_background();
567 let handle_message = (handler)(message, session.clone());
568 drop(span_enter);
569
570 let handle_message = handle_message.instrument(span);
571 if is_background {
572 executor.spawn_detached(handle_message);
573 } else {
574 foreground_message_handlers.push(handle_message);
575 }
576 } else {
577 tracing::error!(%user_id, %login, ?connection_id, %address, "no message handler");
578 }
579 } else {
580 tracing::info!(%user_id, %login, ?connection_id, %address, "connection closed");
581 break;
582 }
583 }
584 }
585 }
586
587 drop(foreground_message_handlers);
588 tracing::info!(%user_id, %login, ?connection_id, %address, "signing out");
589 if let Err(error) = sign_out(session, teardown, executor).await {
590 tracing::error!(%user_id, %login, ?connection_id, %address, ?error, "error signing out");
591 }
592
593 Ok(())
594 }.instrument(span)
595 }
596
597 pub async fn invite_code_redeemed(
598 self: &Arc<Self>,
599 inviter_id: UserId,
600 invitee_id: UserId,
601 ) -> Result<()> {
602 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
603 if let Some(code) = &user.invite_code {
604 let pool = self.connection_pool.lock();
605 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
606 for connection_id in pool.user_connection_ids(inviter_id) {
607 self.peer.send(
608 connection_id,
609 proto::UpdateContacts {
610 contacts: vec![invitee_contact.clone()],
611 ..Default::default()
612 },
613 )?;
614 self.peer.send(
615 connection_id,
616 proto::UpdateInviteInfo {
617 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
618 count: user.invite_count as u32,
619 },
620 )?;
621 }
622 }
623 }
624 Ok(())
625 }
626
627 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
628 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
629 if let Some(invite_code) = &user.invite_code {
630 let pool = self.connection_pool.lock();
631 for connection_id in pool.user_connection_ids(user_id) {
632 self.peer.send(
633 connection_id,
634 proto::UpdateInviteInfo {
635 url: format!(
636 "{}{}",
637 self.app_state.config.invite_link_prefix, invite_code
638 ),
639 count: user.invite_count as u32,
640 },
641 )?;
642 }
643 }
644 }
645 Ok(())
646 }
647
648 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
649 ServerSnapshot {
650 connection_pool: ConnectionPoolGuard {
651 guard: self.connection_pool.lock(),
652 _not_send: PhantomData,
653 },
654 peer: &self.peer,
655 }
656 }
657}
658
659impl<'a> Deref for ConnectionPoolGuard<'a> {
660 type Target = ConnectionPool;
661
662 fn deref(&self) -> &Self::Target {
663 &*self.guard
664 }
665}
666
667impl<'a> DerefMut for ConnectionPoolGuard<'a> {
668 fn deref_mut(&mut self) -> &mut Self::Target {
669 &mut *self.guard
670 }
671}
672
673impl<'a> Drop for ConnectionPoolGuard<'a> {
674 fn drop(&mut self) {
675 #[cfg(test)]
676 self.check_invariants();
677 }
678}
679
680fn broadcast<F>(
681 sender_id: ConnectionId,
682 receiver_ids: impl IntoIterator<Item = ConnectionId>,
683 mut f: F,
684) where
685 F: FnMut(ConnectionId) -> anyhow::Result<()>,
686{
687 for receiver_id in receiver_ids {
688 if receiver_id != sender_id {
689 f(receiver_id).trace_err();
690 }
691 }
692}
693
694lazy_static! {
695 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
696}
697
698pub struct ProtocolVersion(u32);
699
700impl Header for ProtocolVersion {
701 fn name() -> &'static HeaderName {
702 &ZED_PROTOCOL_VERSION
703 }
704
705 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
706 where
707 Self: Sized,
708 I: Iterator<Item = &'i axum::http::HeaderValue>,
709 {
710 let version = values
711 .next()
712 .ok_or_else(axum::headers::Error::invalid)?
713 .to_str()
714 .map_err(|_| axum::headers::Error::invalid())?
715 .parse()
716 .map_err(|_| axum::headers::Error::invalid())?;
717 Ok(Self(version))
718 }
719
720 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
721 values.extend([self.0.to_string().parse().unwrap()]);
722 }
723}
724
725pub fn routes(server: Arc<Server>) -> Router<Body> {
726 Router::new()
727 .route("/rpc", get(handle_websocket_request))
728 .layer(
729 ServiceBuilder::new()
730 .layer(Extension(server.app_state.clone()))
731 .layer(middleware::from_fn(auth::validate_header)),
732 )
733 .route("/metrics", get(handle_metrics))
734 .layer(Extension(server))
735}
736
737pub async fn handle_websocket_request(
738 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
739 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
740 Extension(server): Extension<Arc<Server>>,
741 Extension(user): Extension<User>,
742 ws: WebSocketUpgrade,
743) -> axum::response::Response {
744 if protocol_version != rpc::PROTOCOL_VERSION {
745 return (
746 StatusCode::UPGRADE_REQUIRED,
747 "client must be upgraded".to_string(),
748 )
749 .into_response();
750 }
751 let socket_address = socket_address.to_string();
752 ws.on_upgrade(move |socket| {
753 use util::ResultExt;
754 let socket = socket
755 .map_ok(to_tungstenite_message)
756 .err_into()
757 .with(|message| async move { Ok(to_axum_message(message)) });
758 let connection = Connection::new(Box::pin(socket));
759 async move {
760 server
761 .handle_connection(connection, socket_address, user, None, Executor::Production)
762 .await
763 .log_err();
764 }
765 })
766}
767
768pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
769 let connections = server
770 .connection_pool
771 .lock()
772 .connections()
773 .filter(|connection| !connection.admin)
774 .count();
775
776 METRIC_CONNECTIONS.set(connections as _);
777
778 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
779 METRIC_SHARED_PROJECTS.set(shared_projects as _);
780
781 let encoder = prometheus::TextEncoder::new();
782 let metric_families = prometheus::gather();
783 let encoded_metrics = encoder
784 .encode_to_string(&metric_families)
785 .map_err(|err| anyhow!("{}", err))?;
786 Ok(encoded_metrics)
787}
788
/// Cleans up after a session whose connection loop has ended.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - asks the database which projects the connection left (`connection_lost`)
///   and calls `project_left` for each.
///
/// Then, unless the teardown channel fires first, it waits
/// [`RECONNECT_TIMEOUT`] and:
/// - calls `leave_room_for_session`;
/// - declines the user's pending call if they have no connection online;
/// - refreshes the user's contacts.
///
/// # Errors
/// Returns an error if `remove_connection` or `update_user_contacts` fails;
/// failures of the other database steps are only logged.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Logged rather than propagated so the rest of the sign-out still runs.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // Give the client a chance to reconnect before tearing down room state;
    // a server teardown skips the delayed cleanup entirely.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline a pending call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
834
835async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
836 response.send(proto::Ack {})?;
837 Ok(())
838}
839
840async fn create_room(
841 _request: proto::CreateRoom,
842 response: Response<proto::CreateRoom>,
843 session: Session,
844) -> Result<()> {
845 let live_kit_room = nanoid::nanoid!(30);
846 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
847 if let Some(_) = live_kit
848 .create_room(live_kit_room.clone())
849 .await
850 .trace_err()
851 {
852 if let Some(token) = live_kit
853 .room_token(&live_kit_room, &session.connection_id.to_string())
854 .trace_err()
855 {
856 Some(proto::LiveKitConnectionInfo {
857 server_url: live_kit.url().into(),
858 token,
859 })
860 } else {
861 None
862 }
863 } else {
864 None
865 }
866 } else {
867 None
868 };
869
870 {
871 let room = session
872 .db()
873 .await
874 .create_room(session.user_id, session.connection_id, &live_kit_room)
875 .await?;
876
877 response.send(proto::CreateRoomResponse {
878 room: Some(room.clone()),
879 live_kit_connection_info,
880 })?;
881 }
882
883 update_user_contacts(session.user_id, &session).await?;
884 Ok(())
885}
886
887async fn join_room(
888 request: proto::JoinRoom,
889 response: Response<proto::JoinRoom>,
890 session: Session,
891) -> Result<()> {
892 let room_id = RoomId::from_proto(request.id);
893 let room = {
894 let room = session
895 .db()
896 .await
897 .join_room(room_id, session.user_id, session.connection_id)
898 .await?;
899 room_updated(&room, &session.peer);
900 room.clone()
901 };
902
903 for connection_id in session
904 .connection_pool()
905 .await
906 .user_connection_ids(session.user_id)
907 {
908 session
909 .peer
910 .send(
911 connection_id,
912 proto::CallCanceled {
913 room_id: room_id.to_proto(),
914 },
915 )
916 .trace_err();
917 }
918
919 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
920 if let Some(token) = live_kit
921 .room_token(&room.live_kit_room, &session.connection_id.to_string())
922 .trace_err()
923 {
924 Some(proto::LiveKitConnectionInfo {
925 server_url: live_kit.url().into(),
926 token,
927 })
928 } else {
929 None
930 }
931 } else {
932 None
933 };
934
935 response.send(proto::JoinRoomResponse {
936 room: Some(room),
937 live_kit_connection_info,
938 })?;
939
940 update_user_contacts(session.user_id, &session).await?;
941 Ok(())
942}
943
944async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
945 leave_room_for_session(&session).await
946}
947
948async fn call(
949 request: proto::Call,
950 response: Response<proto::Call>,
951 session: Session,
952) -> Result<()> {
953 let room_id = RoomId::from_proto(request.room_id);
954 let calling_user_id = session.user_id;
955 let calling_connection_id = session.connection_id;
956 let called_user_id = UserId::from_proto(request.called_user_id);
957 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
958 if !session
959 .db()
960 .await
961 .has_contact(calling_user_id, called_user_id)
962 .await?
963 {
964 return Err(anyhow!("cannot call a user who isn't a contact"))?;
965 }
966
967 let incoming_call = {
968 let (room, incoming_call) = &mut *session
969 .db()
970 .await
971 .call(
972 room_id,
973 calling_user_id,
974 calling_connection_id,
975 called_user_id,
976 initial_project_id,
977 )
978 .await?;
979 room_updated(&room, &session.peer);
980 mem::take(incoming_call)
981 };
982 update_user_contacts(called_user_id, &session).await?;
983
984 let mut calls = session
985 .connection_pool()
986 .await
987 .user_connection_ids(called_user_id)
988 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
989 .collect::<FuturesUnordered<_>>();
990
991 while let Some(call_response) = calls.next().await {
992 match call_response.as_ref() {
993 Ok(_) => {
994 response.send(proto::Ack {})?;
995 return Ok(());
996 }
997 Err(_) => {
998 call_response.trace_err();
999 }
1000 }
1001 }
1002
1003 {
1004 let room = session
1005 .db()
1006 .await
1007 .call_failed(room_id, called_user_id)
1008 .await?;
1009 room_updated(&room, &session.peer);
1010 }
1011 update_user_contacts(called_user_id, &session).await?;
1012
1013 Err(anyhow!("failed to ring user"))?
1014}
1015
1016async fn cancel_call(
1017 request: proto::CancelCall,
1018 response: Response<proto::CancelCall>,
1019 session: Session,
1020) -> Result<()> {
1021 let called_user_id = UserId::from_proto(request.called_user_id);
1022 let room_id = RoomId::from_proto(request.room_id);
1023 {
1024 let room = session
1025 .db()
1026 .await
1027 .cancel_call(room_id, session.connection_id, called_user_id)
1028 .await?;
1029 room_updated(&room, &session.peer);
1030 }
1031
1032 for connection_id in session
1033 .connection_pool()
1034 .await
1035 .user_connection_ids(called_user_id)
1036 {
1037 session
1038 .peer
1039 .send(
1040 connection_id,
1041 proto::CallCanceled {
1042 room_id: room_id.to_proto(),
1043 },
1044 )
1045 .trace_err();
1046 }
1047 response.send(proto::Ack {})?;
1048
1049 update_user_contacts(called_user_id, &session).await?;
1050 Ok(())
1051}
1052
1053async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1054 let room_id = RoomId::from_proto(message.room_id);
1055 {
1056 let room = session
1057 .db()
1058 .await
1059 .decline_call(Some(room_id), session.user_id)
1060 .await?
1061 .ok_or_else(|| anyhow!("failed to decline call"))?;
1062 room_updated(&room, &session.peer);
1063 }
1064
1065 for connection_id in session
1066 .connection_pool()
1067 .await
1068 .user_connection_ids(session.user_id)
1069 {
1070 session
1071 .peer
1072 .send(
1073 connection_id,
1074 proto::CallCanceled {
1075 room_id: room_id.to_proto(),
1076 },
1077 )
1078 .trace_err();
1079 }
1080 update_user_contacts(session.user_id, &session).await?;
1081 Ok(())
1082}
1083
1084async fn update_participant_location(
1085 request: proto::UpdateParticipantLocation,
1086 response: Response<proto::UpdateParticipantLocation>,
1087 session: Session,
1088) -> Result<()> {
1089 let room_id = RoomId::from_proto(request.room_id);
1090 let location = request
1091 .location
1092 .ok_or_else(|| anyhow!("invalid location"))?;
1093 let room = session
1094 .db()
1095 .await
1096 .update_room_participant_location(room_id, session.connection_id, location)
1097 .await?;
1098 room_updated(&room, &session.peer);
1099 response.send(proto::Ack {})?;
1100 Ok(())
1101}
1102
1103async fn share_project(
1104 request: proto::ShareProject,
1105 response: Response<proto::ShareProject>,
1106 session: Session,
1107) -> Result<()> {
1108 let (project_id, room) = &*session
1109 .db()
1110 .await
1111 .share_project(
1112 RoomId::from_proto(request.room_id),
1113 session.connection_id,
1114 &request.worktrees,
1115 )
1116 .await?;
1117 response.send(proto::ShareProjectResponse {
1118 project_id: project_id.to_proto(),
1119 })?;
1120 room_updated(&room, &session.peer);
1121
1122 Ok(())
1123}
1124
1125async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1126 let project_id = ProjectId::from_proto(message.project_id);
1127
1128 let (room, guest_connection_ids) = &*session
1129 .db()
1130 .await
1131 .unshare_project(project_id, session.connection_id)
1132 .await?;
1133
1134 broadcast(
1135 session.connection_id,
1136 guest_connection_ids.iter().copied(),
1137 |conn_id| session.peer.send(conn_id, message.clone()),
1138 );
1139 room_updated(&room, &session.peer);
1140
1141 Ok(())
1142}
1143
1144async fn join_project(
1145 request: proto::JoinProject,
1146 response: Response<proto::JoinProject>,
1147 session: Session,
1148) -> Result<()> {
1149 let project_id = ProjectId::from_proto(request.project_id);
1150 let guest_user_id = session.user_id;
1151
1152 tracing::info!(%project_id, "join project");
1153
1154 let (project, replica_id) = &mut *session
1155 .db()
1156 .await
1157 .join_project(project_id, session.connection_id)
1158 .await?;
1159
1160 let collaborators = project
1161 .collaborators
1162 .iter()
1163 .map(|collaborator| {
1164 let peer_id = proto::PeerId {
1165 owner_id: collaborator.connection_server_id.0 as u32,
1166 id: collaborator.connection_id as u32,
1167 };
1168 proto::Collaborator {
1169 peer_id: Some(peer_id),
1170 replica_id: collaborator.replica_id.0 as u32,
1171 user_id: collaborator.user_id.to_proto(),
1172 }
1173 })
1174 .filter(|collaborator| collaborator.peer_id != Some(session.connection_id.into()))
1175 .collect::<Vec<_>>();
1176 let worktrees = project
1177 .worktrees
1178 .iter()
1179 .map(|(id, worktree)| proto::WorktreeMetadata {
1180 id: *id,
1181 root_name: worktree.root_name.clone(),
1182 visible: worktree.visible,
1183 abs_path: worktree.abs_path.clone(),
1184 })
1185 .collect::<Vec<_>>();
1186
1187 for collaborator in &collaborators {
1188 session
1189 .peer
1190 .send(
1191 collaborator.peer_id.unwrap().into(),
1192 proto::AddProjectCollaborator {
1193 project_id: project_id.to_proto(),
1194 collaborator: Some(proto::Collaborator {
1195 peer_id: Some(session.connection_id.into()),
1196 replica_id: replica_id.0 as u32,
1197 user_id: guest_user_id.to_proto(),
1198 }),
1199 },
1200 )
1201 .trace_err();
1202 }
1203
1204 // First, we send the metadata associated with each worktree.
1205 response.send(proto::JoinProjectResponse {
1206 worktrees: worktrees.clone(),
1207 replica_id: replica_id.0 as u32,
1208 collaborators: collaborators.clone(),
1209 language_servers: project.language_servers.clone(),
1210 })?;
1211
1212 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1213 #[cfg(any(test, feature = "test-support"))]
1214 const MAX_CHUNK_SIZE: usize = 2;
1215 #[cfg(not(any(test, feature = "test-support")))]
1216 const MAX_CHUNK_SIZE: usize = 256;
1217
1218 // Stream this worktree's entries.
1219 let message = proto::UpdateWorktree {
1220 project_id: project_id.to_proto(),
1221 worktree_id,
1222 abs_path: worktree.abs_path.clone(),
1223 root_name: worktree.root_name,
1224 updated_entries: worktree.entries,
1225 removed_entries: Default::default(),
1226 scan_id: worktree.scan_id,
1227 is_last_update: worktree.is_complete,
1228 };
1229 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1230 session.peer.send(session.connection_id, update.clone())?;
1231 }
1232
1233 // Stream this worktree's diagnostics.
1234 for summary in worktree.diagnostic_summaries {
1235 session.peer.send(
1236 session.connection_id,
1237 proto::UpdateDiagnosticSummary {
1238 project_id: project_id.to_proto(),
1239 worktree_id: worktree.id,
1240 summary: Some(summary),
1241 },
1242 )?;
1243 }
1244 }
1245
1246 for language_server in &project.language_servers {
1247 session.peer.send(
1248 session.connection_id,
1249 proto::UpdateLanguageServer {
1250 project_id: project_id.to_proto(),
1251 language_server_id: language_server.id,
1252 variant: Some(
1253 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1254 proto::LspDiskBasedDiagnosticsUpdated {},
1255 ),
1256 ),
1257 },
1258 )?;
1259 }
1260
1261 Ok(())
1262}
1263
1264async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1265 let sender_id = session.connection_id;
1266 let project_id = ProjectId::from_proto(request.project_id);
1267
1268 let project = session
1269 .db()
1270 .await
1271 .leave_project(project_id, sender_id)
1272 .await?;
1273 tracing::info!(
1274 %project_id,
1275 host_user_id = %project.host_user_id,
1276 host_connection_id = %project.host_connection_id,
1277 "leave project"
1278 );
1279 project_left(&project, &session);
1280
1281 Ok(())
1282}
1283
1284async fn update_project(
1285 request: proto::UpdateProject,
1286 response: Response<proto::UpdateProject>,
1287 session: Session,
1288) -> Result<()> {
1289 let project_id = ProjectId::from_proto(request.project_id);
1290 let (room, guest_connection_ids) = &*session
1291 .db()
1292 .await
1293 .update_project(project_id, session.connection_id, &request.worktrees)
1294 .await?;
1295 broadcast(
1296 session.connection_id,
1297 guest_connection_ids.iter().copied(),
1298 |connection_id| {
1299 session
1300 .peer
1301 .forward_send(session.connection_id, connection_id, request.clone())
1302 },
1303 );
1304 room_updated(&room, &session.peer);
1305 response.send(proto::Ack {})?;
1306
1307 Ok(())
1308}
1309
1310async fn update_worktree(
1311 request: proto::UpdateWorktree,
1312 response: Response<proto::UpdateWorktree>,
1313 session: Session,
1314) -> Result<()> {
1315 let guest_connection_ids = session
1316 .db()
1317 .await
1318 .update_worktree(&request, session.connection_id)
1319 .await?;
1320
1321 broadcast(
1322 session.connection_id,
1323 guest_connection_ids.iter().copied(),
1324 |connection_id| {
1325 session
1326 .peer
1327 .forward_send(session.connection_id, connection_id, request.clone())
1328 },
1329 );
1330 response.send(proto::Ack {})?;
1331 Ok(())
1332}
1333
1334async fn update_diagnostic_summary(
1335 message: proto::UpdateDiagnosticSummary,
1336 session: Session,
1337) -> Result<()> {
1338 let guest_connection_ids = session
1339 .db()
1340 .await
1341 .update_diagnostic_summary(&message, session.connection_id)
1342 .await?;
1343
1344 broadcast(
1345 session.connection_id,
1346 guest_connection_ids.iter().copied(),
1347 |connection_id| {
1348 session
1349 .peer
1350 .forward_send(session.connection_id, connection_id, message.clone())
1351 },
1352 );
1353
1354 Ok(())
1355}
1356
1357async fn start_language_server(
1358 request: proto::StartLanguageServer,
1359 session: Session,
1360) -> Result<()> {
1361 let guest_connection_ids = session
1362 .db()
1363 .await
1364 .start_language_server(&request, session.connection_id)
1365 .await?;
1366
1367 broadcast(
1368 session.connection_id,
1369 guest_connection_ids.iter().copied(),
1370 |connection_id| {
1371 session
1372 .peer
1373 .forward_send(session.connection_id, connection_id, request.clone())
1374 },
1375 );
1376 Ok(())
1377}
1378
1379async fn update_language_server(
1380 request: proto::UpdateLanguageServer,
1381 session: Session,
1382) -> Result<()> {
1383 let project_id = ProjectId::from_proto(request.project_id);
1384 let project_connection_ids = session
1385 .db()
1386 .await
1387 .project_connection_ids(project_id, session.connection_id)
1388 .await?;
1389 broadcast(
1390 session.connection_id,
1391 project_connection_ids.iter().copied(),
1392 |connection_id| {
1393 session
1394 .peer
1395 .forward_send(session.connection_id, connection_id, request.clone())
1396 },
1397 );
1398 Ok(())
1399}
1400
1401async fn forward_project_request<T>(
1402 request: T,
1403 response: Response<T>,
1404 session: Session,
1405) -> Result<()>
1406where
1407 T: EntityMessage + RequestMessage,
1408{
1409 let project_id = ProjectId::from_proto(request.remote_entity_id());
1410 let host_connection_id = {
1411 let collaborators = session
1412 .db()
1413 .await
1414 .project_collaborators(project_id, session.connection_id)
1415 .await?;
1416 let host = collaborators
1417 .iter()
1418 .find(|collaborator| collaborator.is_host)
1419 .ok_or_else(|| anyhow!("host not found"))?;
1420 ConnectionId {
1421 owner_id: host.connection_server_id.0 as u32,
1422 id: host.connection_id as u32,
1423 }
1424 };
1425
1426 let payload = session
1427 .peer
1428 .forward_request(session.connection_id, host_connection_id, request)
1429 .await?;
1430
1431 response.send(payload)?;
1432 Ok(())
1433}
1434
1435async fn save_buffer(
1436 request: proto::SaveBuffer,
1437 response: Response<proto::SaveBuffer>,
1438 session: Session,
1439) -> Result<()> {
1440 let project_id = ProjectId::from_proto(request.project_id);
1441 let host_connection_id = {
1442 let collaborators = session
1443 .db()
1444 .await
1445 .project_collaborators(project_id, session.connection_id)
1446 .await?;
1447 let host = collaborators
1448 .iter()
1449 .find(|collaborator| collaborator.is_host)
1450 .ok_or_else(|| anyhow!("host not found"))?;
1451 ConnectionId {
1452 owner_id: host.connection_server_id.0 as u32,
1453 id: host.connection_id as u32,
1454 }
1455 };
1456 let response_payload = session
1457 .peer
1458 .forward_request(session.connection_id, host_connection_id, request.clone())
1459 .await?;
1460
1461 let mut collaborators = session
1462 .db()
1463 .await
1464 .project_collaborators(project_id, session.connection_id)
1465 .await?;
1466 collaborators.retain(|collaborator| {
1467 let collaborator_connection = ConnectionId {
1468 owner_id: collaborator.connection_server_id.0 as u32,
1469 id: collaborator.connection_id as u32,
1470 };
1471 collaborator_connection != session.connection_id
1472 });
1473 let project_connection_ids = collaborators.iter().map(|collaborator| ConnectionId {
1474 owner_id: collaborator.connection_server_id.0 as u32,
1475 id: collaborator.connection_id as u32,
1476 });
1477 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1478 session
1479 .peer
1480 .forward_send(host_connection_id, conn_id, response_payload.clone())
1481 });
1482 response.send(response_payload)?;
1483 Ok(())
1484}
1485
1486async fn create_buffer_for_peer(
1487 request: proto::CreateBufferForPeer,
1488 session: Session,
1489) -> Result<()> {
1490 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1491 session
1492 .peer
1493 .forward_send(session.connection_id, peer_id.into(), request)?;
1494 Ok(())
1495}
1496
1497async fn update_buffer(
1498 request: proto::UpdateBuffer,
1499 response: Response<proto::UpdateBuffer>,
1500 session: Session,
1501) -> Result<()> {
1502 let project_id = ProjectId::from_proto(request.project_id);
1503 let project_connection_ids = session
1504 .db()
1505 .await
1506 .project_connection_ids(project_id, session.connection_id)
1507 .await?;
1508
1509 dbg!(session.connection_id, &*project_connection_ids);
1510 broadcast(
1511 session.connection_id,
1512 project_connection_ids.iter().copied(),
1513 |connection_id| {
1514 session
1515 .peer
1516 .forward_send(session.connection_id, connection_id, request.clone())
1517 },
1518 );
1519 response.send(proto::Ack {})?;
1520 Ok(())
1521}
1522
1523async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1524 let project_id = ProjectId::from_proto(request.project_id);
1525 let project_connection_ids = session
1526 .db()
1527 .await
1528 .project_connection_ids(project_id, session.connection_id)
1529 .await?;
1530
1531 broadcast(
1532 session.connection_id,
1533 project_connection_ids.iter().copied(),
1534 |connection_id| {
1535 session
1536 .peer
1537 .forward_send(session.connection_id, connection_id, request.clone())
1538 },
1539 );
1540 Ok(())
1541}
1542
1543async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1544 let project_id = ProjectId::from_proto(request.project_id);
1545 let project_connection_ids = session
1546 .db()
1547 .await
1548 .project_connection_ids(project_id, session.connection_id)
1549 .await?;
1550 broadcast(
1551 session.connection_id,
1552 project_connection_ids.iter().copied(),
1553 |connection_id| {
1554 session
1555 .peer
1556 .forward_send(session.connection_id, connection_id, request.clone())
1557 },
1558 );
1559 Ok(())
1560}
1561
1562async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1563 let project_id = ProjectId::from_proto(request.project_id);
1564 let project_connection_ids = session
1565 .db()
1566 .await
1567 .project_connection_ids(project_id, session.connection_id)
1568 .await?;
1569 broadcast(
1570 session.connection_id,
1571 project_connection_ids.iter().copied(),
1572 |connection_id| {
1573 session
1574 .peer
1575 .forward_send(session.connection_id, connection_id, request.clone())
1576 },
1577 );
1578 Ok(())
1579}
1580
1581async fn follow(
1582 request: proto::Follow,
1583 response: Response<proto::Follow>,
1584 session: Session,
1585) -> Result<()> {
1586 let project_id = ProjectId::from_proto(request.project_id);
1587 let leader_id = request
1588 .leader_id
1589 .ok_or_else(|| anyhow!("invalid leader id"))?
1590 .into();
1591 let follower_id = session.connection_id;
1592 {
1593 let project_connection_ids = session
1594 .db()
1595 .await
1596 .project_connection_ids(project_id, session.connection_id)
1597 .await?;
1598
1599 if !project_connection_ids.contains(&leader_id) {
1600 Err(anyhow!("no such peer"))?;
1601 }
1602 }
1603
1604 let mut response_payload = session
1605 .peer
1606 .forward_request(session.connection_id, leader_id, request)
1607 .await?;
1608 response_payload
1609 .views
1610 .retain(|view| view.leader_id != Some(follower_id.into()));
1611 response.send(response_payload)?;
1612 Ok(())
1613}
1614
1615async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1616 let project_id = ProjectId::from_proto(request.project_id);
1617 let leader_id = request
1618 .leader_id
1619 .ok_or_else(|| anyhow!("invalid leader id"))?
1620 .into();
1621 let project_connection_ids = session
1622 .db()
1623 .await
1624 .project_connection_ids(project_id, session.connection_id)
1625 .await?;
1626 if !project_connection_ids.contains(&leader_id) {
1627 Err(anyhow!("no such peer"))?;
1628 }
1629 session
1630 .peer
1631 .forward_send(session.connection_id, leader_id, request)?;
1632 Ok(())
1633}
1634
1635async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1636 let project_id = ProjectId::from_proto(request.project_id);
1637 let project_connection_ids = session
1638 .db
1639 .lock()
1640 .await
1641 .project_connection_ids(project_id, session.connection_id)
1642 .await?;
1643
1644 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1645 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1646 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1647 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1648 });
1649 for follower_peer_id in request.follower_ids.iter().copied() {
1650 let follower_connection_id = follower_peer_id.into();
1651 if project_connection_ids.contains(&follower_connection_id)
1652 && Some(follower_peer_id) != leader_id
1653 {
1654 session.peer.forward_send(
1655 session.connection_id,
1656 follower_connection_id,
1657 request.clone(),
1658 )?;
1659 }
1660 }
1661 Ok(())
1662}
1663
1664async fn get_users(
1665 request: proto::GetUsers,
1666 response: Response<proto::GetUsers>,
1667 session: Session,
1668) -> Result<()> {
1669 let user_ids = request
1670 .user_ids
1671 .into_iter()
1672 .map(UserId::from_proto)
1673 .collect();
1674 let users = session
1675 .db()
1676 .await
1677 .get_users_by_ids(user_ids)
1678 .await?
1679 .into_iter()
1680 .map(|user| proto::User {
1681 id: user.id.to_proto(),
1682 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1683 github_login: user.github_login,
1684 })
1685 .collect();
1686 response.send(proto::UsersResponse { users })?;
1687 Ok(())
1688}
1689
1690async fn fuzzy_search_users(
1691 request: proto::FuzzySearchUsers,
1692 response: Response<proto::FuzzySearchUsers>,
1693 session: Session,
1694) -> Result<()> {
1695 let query = request.query;
1696 let users = match query.len() {
1697 0 => vec![],
1698 1 | 2 => session
1699 .db()
1700 .await
1701 .get_user_by_github_account(&query, None)
1702 .await?
1703 .into_iter()
1704 .collect(),
1705 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1706 };
1707 let users = users
1708 .into_iter()
1709 .filter(|user| user.id != session.user_id)
1710 .map(|user| proto::User {
1711 id: user.id.to_proto(),
1712 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1713 github_login: user.github_login,
1714 })
1715 .collect();
1716 response.send(proto::UsersResponse { users })?;
1717 Ok(())
1718}
1719
1720async fn request_contact(
1721 request: proto::RequestContact,
1722 response: Response<proto::RequestContact>,
1723 session: Session,
1724) -> Result<()> {
1725 let requester_id = session.user_id;
1726 let responder_id = UserId::from_proto(request.responder_id);
1727 if requester_id == responder_id {
1728 return Err(anyhow!("cannot add yourself as a contact"))?;
1729 }
1730
1731 session
1732 .db()
1733 .await
1734 .send_contact_request(requester_id, responder_id)
1735 .await?;
1736
1737 // Update outgoing contact requests of requester
1738 let mut update = proto::UpdateContacts::default();
1739 update.outgoing_requests.push(responder_id.to_proto());
1740 for connection_id in session
1741 .connection_pool()
1742 .await
1743 .user_connection_ids(requester_id)
1744 {
1745 session.peer.send(connection_id, update.clone())?;
1746 }
1747
1748 // Update incoming contact requests of responder
1749 let mut update = proto::UpdateContacts::default();
1750 update
1751 .incoming_requests
1752 .push(proto::IncomingContactRequest {
1753 requester_id: requester_id.to_proto(),
1754 should_notify: true,
1755 });
1756 for connection_id in session
1757 .connection_pool()
1758 .await
1759 .user_connection_ids(responder_id)
1760 {
1761 session.peer.send(connection_id, update.clone())?;
1762 }
1763
1764 response.send(proto::Ack {})?;
1765 Ok(())
1766}
1767
1768async fn respond_to_contact_request(
1769 request: proto::RespondToContactRequest,
1770 response: Response<proto::RespondToContactRequest>,
1771 session: Session,
1772) -> Result<()> {
1773 let responder_id = session.user_id;
1774 let requester_id = UserId::from_proto(request.requester_id);
1775 let db = session.db().await;
1776 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1777 db.dismiss_contact_notification(responder_id, requester_id)
1778 .await?;
1779 } else {
1780 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1781
1782 db.respond_to_contact_request(responder_id, requester_id, accept)
1783 .await?;
1784 let requester_busy = db.is_user_busy(requester_id).await?;
1785 let responder_busy = db.is_user_busy(responder_id).await?;
1786
1787 let pool = session.connection_pool().await;
1788 // Update responder with new contact
1789 let mut update = proto::UpdateContacts::default();
1790 if accept {
1791 update
1792 .contacts
1793 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1794 }
1795 update
1796 .remove_incoming_requests
1797 .push(requester_id.to_proto());
1798 for connection_id in pool.user_connection_ids(responder_id) {
1799 session.peer.send(connection_id, update.clone())?;
1800 }
1801
1802 // Update requester with new contact
1803 let mut update = proto::UpdateContacts::default();
1804 if accept {
1805 update
1806 .contacts
1807 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1808 }
1809 update
1810 .remove_outgoing_requests
1811 .push(responder_id.to_proto());
1812 for connection_id in pool.user_connection_ids(requester_id) {
1813 session.peer.send(connection_id, update.clone())?;
1814 }
1815 }
1816
1817 response.send(proto::Ack {})?;
1818 Ok(())
1819}
1820
1821async fn remove_contact(
1822 request: proto::RemoveContact,
1823 response: Response<proto::RemoveContact>,
1824 session: Session,
1825) -> Result<()> {
1826 let requester_id = session.user_id;
1827 let responder_id = UserId::from_proto(request.user_id);
1828 let db = session.db().await;
1829 db.remove_contact(requester_id, responder_id).await?;
1830
1831 let pool = session.connection_pool().await;
1832 // Update outgoing contact requests of requester
1833 let mut update = proto::UpdateContacts::default();
1834 update
1835 .remove_outgoing_requests
1836 .push(responder_id.to_proto());
1837 for connection_id in pool.user_connection_ids(requester_id) {
1838 session.peer.send(connection_id, update.clone())?;
1839 }
1840
1841 // Update incoming contact requests of responder
1842 let mut update = proto::UpdateContacts::default();
1843 update
1844 .remove_incoming_requests
1845 .push(requester_id.to_proto());
1846 for connection_id in pool.user_connection_ids(responder_id) {
1847 session.peer.send(connection_id, update.clone())?;
1848 }
1849
1850 response.send(proto::Ack {})?;
1851 Ok(())
1852}
1853
1854async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1855 let project_id = ProjectId::from_proto(request.project_id);
1856 let project_connection_ids = session
1857 .db()
1858 .await
1859 .project_connection_ids(project_id, session.connection_id)
1860 .await?;
1861 broadcast(
1862 session.connection_id,
1863 project_connection_ids.iter().copied(),
1864 |connection_id| {
1865 session
1866 .peer
1867 .forward_send(session.connection_id, connection_id, request.clone())
1868 },
1869 );
1870 Ok(())
1871}
1872
1873async fn get_private_user_info(
1874 _request: proto::GetPrivateUserInfo,
1875 response: Response<proto::GetPrivateUserInfo>,
1876 session: Session,
1877) -> Result<()> {
1878 let metrics_id = session
1879 .db()
1880 .await
1881 .get_user_metrics_id(session.user_id)
1882 .await?;
1883 let user = session
1884 .db()
1885 .await
1886 .get_user_by_id(session.user_id)
1887 .await?
1888 .ok_or_else(|| anyhow!("user not found"))?;
1889 response.send(proto::GetPrivateUserInfoResponse {
1890 metrics_id,
1891 staff: user.admin,
1892 })?;
1893 Ok(())
1894}
1895
1896fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1897 match message {
1898 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1899 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1900 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1901 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1902 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1903 code: frame.code.into(),
1904 reason: frame.reason,
1905 })),
1906 }
1907}
1908
1909fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1910 match message {
1911 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1912 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1913 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1914 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1915 AxumMessage::Close(frame) => {
1916 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1917 code: frame.code.into(),
1918 reason: frame.reason,
1919 }))
1920 }
1921 }
1922}
1923
1924fn build_initial_contacts_update(
1925 contacts: Vec<db::Contact>,
1926 pool: &ConnectionPool,
1927) -> proto::UpdateContacts {
1928 let mut update = proto::UpdateContacts::default();
1929
1930 for contact in contacts {
1931 match contact {
1932 db::Contact::Accepted {
1933 user_id,
1934 should_notify,
1935 busy,
1936 } => {
1937 update
1938 .contacts
1939 .push(contact_for_user(user_id, should_notify, busy, &pool));
1940 }
1941 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1942 db::Contact::Incoming {
1943 user_id,
1944 should_notify,
1945 } => update
1946 .incoming_requests
1947 .push(proto::IncomingContactRequest {
1948 requester_id: user_id.to_proto(),
1949 should_notify,
1950 }),
1951 }
1952 }
1953
1954 update
1955}
1956
1957fn contact_for_user(
1958 user_id: UserId,
1959 should_notify: bool,
1960 busy: bool,
1961 pool: &ConnectionPool,
1962) -> proto::Contact {
1963 proto::Contact {
1964 user_id: user_id.to_proto(),
1965 online: pool.is_user_online(user_id),
1966 busy,
1967 should_notify,
1968 }
1969}
1970
1971fn room_updated(room: &proto::Room, peer: &Peer) {
1972 for participant in &room.participants {
1973 if let Some(peer_id) = participant
1974 .peer_id
1975 .ok_or_else(|| anyhow!("invalid participant peer id"))
1976 .trace_err()
1977 {
1978 peer.send(
1979 peer_id.into(),
1980 proto::RoomUpdated {
1981 room: Some(room.clone()),
1982 },
1983 )
1984 .trace_err();
1985 }
1986 }
1987}
1988
1989async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1990 let db = session.db().await;
1991 let contacts = db.get_contacts(user_id).await?;
1992 let busy = db.is_user_busy(user_id).await?;
1993
1994 let pool = session.connection_pool().await;
1995 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1996 for contact in contacts {
1997 if let db::Contact::Accepted {
1998 user_id: contact_user_id,
1999 ..
2000 } = contact
2001 {
2002 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2003 session
2004 .peer
2005 .send(
2006 contact_conn_id,
2007 proto::UpdateContacts {
2008 contacts: vec![updated_contact.clone()],
2009 remove_contacts: Default::default(),
2010 incoming_requests: Default::default(),
2011 remove_incoming_requests: Default::default(),
2012 outgoing_requests: Default::default(),
2013 remove_outgoing_requests: Default::default(),
2014 },
2015 )
2016 .trace_err();
2017 }
2018 }
2019 }
2020 Ok(())
2021}
2022
2023async fn leave_room_for_session(session: &Session) -> Result<()> {
2024 let mut contacts_to_update = HashSet::default();
2025
2026 let room_id;
2027 let canceled_calls_to_user_ids;
2028 let live_kit_room;
2029 let delete_live_kit_room;
2030 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2031 contacts_to_update.insert(session.user_id);
2032
2033 for project in left_room.left_projects.values() {
2034 project_left(project, session);
2035 }
2036
2037 room_updated(&left_room.room, &session.peer);
2038 room_id = RoomId::from_proto(left_room.room.id);
2039 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2040 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2041 delete_live_kit_room = left_room.room.participants.is_empty();
2042 } else {
2043 return Ok(());
2044 }
2045
2046 {
2047 let pool = session.connection_pool().await;
2048 for canceled_user_id in canceled_calls_to_user_ids {
2049 for connection_id in pool.user_connection_ids(canceled_user_id) {
2050 session
2051 .peer
2052 .send(
2053 connection_id,
2054 proto::CallCanceled {
2055 room_id: room_id.to_proto(),
2056 },
2057 )
2058 .trace_err();
2059 }
2060 contacts_to_update.insert(canceled_user_id);
2061 }
2062 }
2063
2064 for contact_user_id in contacts_to_update {
2065 update_user_contacts(contact_user_id, &session).await?;
2066 }
2067
2068 if let Some(live_kit) = session.live_kit_client.as_ref() {
2069 live_kit
2070 .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
2071 .await
2072 .trace_err();
2073
2074 if delete_live_kit_room {
2075 live_kit.delete_room(live_kit_room).await.trace_err();
2076 }
2077 }
2078
2079 Ok(())
2080}
2081
2082fn project_left(project: &db::LeftProject, session: &Session) {
2083 for connection_id in &project.connection_ids {
2084 if project.host_user_id == session.user_id {
2085 session
2086 .peer
2087 .send(
2088 *connection_id,
2089 proto::UnshareProject {
2090 project_id: project.id.to_proto(),
2091 },
2092 )
2093 .trace_err();
2094 } else {
2095 session
2096 .peer
2097 .send(
2098 *connection_id,
2099 proto::RemoveProjectCollaborator {
2100 project_id: project.id.to_proto(),
2101 peer_id: Some(session.connection_id.into()),
2102 },
2103 )
2104 .trace_err();
2105 }
2106 }
2107
2108 session
2109 .peer
2110 .send(
2111 session.connection_id,
2112 proto::UnshareProject {
2113 project_id: project.id.to_proto(),
2114 },
2115 )
2116 .trace_err();
2117}
2118
2119pub trait ResultExt {
2120 type Ok;
2121
2122 fn trace_err(self) -> Option<Self::Ok>;
2123}
2124
2125impl<T, E> ResultExt for Result<T, E>
2126where
2127 E: std::fmt::Debug,
2128{
2129 type Ok = T;
2130
2131 fn trace_err(self) -> Option<T> {
2132 match self {
2133 Ok(value) => Some(value),
2134 Err(error) => {
2135 tracing::error!("{:?}", error);
2136 None
2137 }
2138 }
2139 }
2140}