1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection drops before removing the
/// user from their room and refreshing their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before refreshing rooms still attributed to
/// stale server instances.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
140pub struct Server {
141 id: parking_lot::Mutex<ServerId>,
142 peer: Arc<Peer>,
143 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
144 app_state: Arc<AppState>,
145 executor: Executor,
146 handlers: HashMap<TypeId, MessageHandler>,
147 teardown: watch::Sender<()>,
148}
149
150pub(crate) struct ConnectionPoolGuard<'a> {
151 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
152 _not_send: PhantomData<Rc<()>>,
153}
154
155#[derive(Serialize)]
156pub struct ServerSnapshot<'a> {
157 peer: &'a Peer,
158 #[serde(serialize_with = "serialize_deref")]
159 connection_pool: ConnectionPoolGuard<'a>,
160}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 id: parking_lot::Mutex::new(id),
175 peer: Peer::new(id.0 as u32),
176 app_state,
177 executor,
178 connection_pool: Default::default(),
179 handlers: Default::default(),
180 teardown: watch::channel(()).0,
181 };
182
183 server
184 .add_request_handler(ping)
185 .add_request_handler(create_room)
186 .add_request_handler(join_room)
187 .add_request_handler(rejoin_room)
188 .add_message_handler(leave_room)
189 .add_request_handler(call)
190 .add_request_handler(cancel_call)
191 .add_message_handler(decline_call)
192 .add_request_handler(update_participant_location)
193 .add_request_handler(share_project)
194 .add_message_handler(unshare_project)
195 .add_request_handler(join_project)
196 .add_message_handler(leave_project)
197 .add_request_handler(update_project)
198 .add_request_handler(update_worktree)
199 .add_message_handler(start_language_server)
200 .add_message_handler(update_language_server)
201 .add_message_handler(update_diagnostic_summary)
202 .add_request_handler(forward_project_request::<proto::GetHover>)
203 .add_request_handler(forward_project_request::<proto::GetDefinition>)
204 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetReferences>)
206 .add_request_handler(forward_project_request::<proto::SearchProject>)
207 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
208 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
212 .add_request_handler(forward_project_request::<proto::GetCompletions>)
213 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
214 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
216 .add_request_handler(forward_project_request::<proto::PrepareRename>)
217 .add_request_handler(forward_project_request::<proto::PerformRename>)
218 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
219 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
220 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
221 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
222 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
224 .add_message_handler(create_buffer_for_peer)
225 .add_request_handler(update_buffer)
226 .add_message_handler(update_buffer_file)
227 .add_message_handler(buffer_reloaded)
228 .add_message_handler(buffer_saved)
229 .add_request_handler(save_buffer)
230 .add_request_handler(get_users)
231 .add_request_handler(fuzzy_search_users)
232 .add_request_handler(request_contact)
233 .add_request_handler(remove_contact)
234 .add_request_handler(respond_to_contact_request)
235 .add_request_handler(follow)
236 .add_message_handler(unfollow)
237 .add_message_handler(update_followers)
238 .add_message_handler(update_diff_base)
239 .add_request_handler(get_private_user_info);
240
241 Arc::new(server)
242 }
243
244 pub async fn start(&self) -> Result<()> {
245 let server_id = *self.id.lock();
246 let app_state = self.app_state.clone();
247 let peer = self.peer.clone();
248 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
249 let pool = self.connection_pool.clone();
250 let live_kit_client = self.app_state.live_kit_client.clone();
251
252 let span = info_span!("start server");
253 let span_enter = span.enter();
254
255 tracing::info!("begin deleting stale projects");
256 app_state
257 .db
258 .delete_stale_projects(&app_state.config.zed_environment, server_id)
259 .await?;
260 tracing::info!("finish deleting stale projects");
261
262 drop(span_enter);
263 self.executor.spawn_detached(
264 async move {
265 tracing::info!("waiting for cleanup timeout");
266 timeout.await;
267 tracing::info!("cleanup timeout expired, retrieving stale rooms");
268 if let Some(room_ids) = app_state
269 .db
270 .stale_room_ids(&app_state.config.zed_environment, server_id)
271 .await
272 .trace_err()
273 {
274 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
275 for room_id in room_ids {
276 let mut contacts_to_update = HashSet::default();
277 let mut canceled_calls_to_user_ids = Vec::new();
278 let mut live_kit_room = String::new();
279 let mut delete_live_kit_room = false;
280
281 if let Ok(mut refreshed_room) =
282 app_state.db.refresh_room(room_id, server_id).await
283 {
284 tracing::info!(
285 room_id = room_id.0,
286 new_participant_count = refreshed_room.room.participants.len(),
287 "refreshed room"
288 );
289 room_updated(&refreshed_room.room, &peer);
290 contacts_to_update
291 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
292 contacts_to_update
293 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
294 canceled_calls_to_user_ids =
295 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
296 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
297 delete_live_kit_room = refreshed_room.room.participants.is_empty();
298 }
299
300 {
301 let pool = pool.lock();
302 for canceled_user_id in canceled_calls_to_user_ids {
303 for connection_id in pool.user_connection_ids(canceled_user_id) {
304 peer.send(
305 connection_id,
306 proto::CallCanceled {
307 room_id: room_id.to_proto(),
308 },
309 )
310 .trace_err();
311 }
312 }
313 }
314
315 for user_id in contacts_to_update {
316 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
317 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
318 if let Some((busy, contacts)) = busy.zip(contacts) {
319 let pool = pool.lock();
320 let updated_contact = contact_for_user(user_id, false, busy, &pool);
321 for contact in contacts {
322 if let db::Contact::Accepted {
323 user_id: contact_user_id,
324 ..
325 } = contact
326 {
327 for contact_conn_id in
328 pool.user_connection_ids(contact_user_id)
329 {
330 peer.send(
331 contact_conn_id,
332 proto::UpdateContacts {
333 contacts: vec![updated_contact.clone()],
334 remove_contacts: Default::default(),
335 incoming_requests: Default::default(),
336 remove_incoming_requests: Default::default(),
337 outgoing_requests: Default::default(),
338 remove_outgoing_requests: Default::default(),
339 },
340 )
341 .trace_err();
342 }
343 }
344 }
345 }
346 }
347
348 if let Some(live_kit) = live_kit_client.as_ref() {
349 if delete_live_kit_room {
350 live_kit.delete_room(live_kit_room).await.trace_err();
351 }
352 }
353 }
354 }
355
356 app_state
357 .db
358 .delete_stale_servers(server_id, &app_state.config.zed_environment)
359 .await
360 .trace_err();
361 }
362 .instrument(span),
363 );
364 Ok(())
365 }
366
367 pub fn teardown(&self) {
368 self.peer.teardown();
369 self.connection_pool.lock().reset();
370 let _ = self.teardown.send(());
371 }
372
373 #[cfg(test)]
374 pub fn reset(&self, id: ServerId) {
375 self.teardown();
376 *self.id.lock() = id;
377 self.peer.reset(id.0 as u32);
378 }
379
380 #[cfg(test)]
381 pub fn id(&self) -> ServerId {
382 *self.id.lock()
383 }
384
385 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
386 where
387 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
388 Fut: 'static + Send + Future<Output = Result<()>>,
389 M: EnvelopedMessage,
390 {
391 let prev_handler = self.handlers.insert(
392 TypeId::of::<M>(),
393 Box::new(move |envelope, session| {
394 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
395 let span = info_span!(
396 "handle message",
397 payload_type = envelope.payload_type_name()
398 );
399 span.in_scope(|| {
400 tracing::info!(
401 payload_type = envelope.payload_type_name(),
402 "message received"
403 );
404 });
405 let future = (handler)(*envelope, session);
406 async move {
407 if let Err(error) = future.await {
408 tracing::error!(%error, "error handling message");
409 }
410 }
411 .instrument(span)
412 .boxed()
413 }),
414 );
415 if prev_handler.is_some() {
416 panic!("registered a handler for the same message twice");
417 }
418 self
419 }
420
421 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
422 where
423 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
424 Fut: 'static + Send + Future<Output = Result<()>>,
425 M: EnvelopedMessage,
426 {
427 self.add_handler(move |envelope, session| handler(envelope.payload, session));
428 self
429 }
430
431 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
432 where
433 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
434 Fut: Send + Future<Output = Result<()>>,
435 M: RequestMessage,
436 {
437 let handler = Arc::new(handler);
438 self.add_handler(move |envelope, session| {
439 let receipt = envelope.receipt();
440 let handler = handler.clone();
441 async move {
442 let peer = session.peer.clone();
443 let responded = Arc::new(AtomicBool::default());
444 let response = Response {
445 peer: peer.clone(),
446 responded: responded.clone(),
447 receipt,
448 };
449 match (handler)(envelope.payload, response, session).await {
450 Ok(()) => {
451 if responded.load(std::sync::atomic::Ordering::SeqCst) {
452 Ok(())
453 } else {
454 Err(anyhow!("handler did not send a response"))?
455 }
456 }
457 Err(error) => {
458 peer.respond_with_error(
459 receipt,
460 proto::Error {
461 message: error.to_string(),
462 },
463 )?;
464 Err(error)
465 }
466 }
467 }
468 })
469 }
470
471 pub fn handle_connection(
472 self: &Arc<Self>,
473 connection: Connection,
474 address: String,
475 user: User,
476 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
477 executor: Executor,
478 ) -> impl Future<Output = Result<()>> {
479 let this = self.clone();
480 let user_id = user.id;
481 let login = user.github_login;
482 let span = info_span!("handle connection", %user_id, %login, %address);
483 let mut teardown = self.teardown.subscribe();
484 async move {
485 let (connection_id, handle_io, mut incoming_rx) = this
486 .peer
487 .add_connection(connection, {
488 let executor = executor.clone();
489 move |duration| executor.sleep(duration)
490 });
491
492 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
493 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
494 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
495
496 if let Some(send_connection_id) = send_connection_id.take() {
497 let _ = send_connection_id.send(connection_id);
498 }
499
500 if !user.connected_once {
501 this.peer.send(connection_id, proto::ShowContacts {})?;
502 this.app_state.db.set_user_connected_once(user_id, true).await?;
503 }
504
505 let (contacts, invite_code) = future::try_join(
506 this.app_state.db.get_contacts(user_id),
507 this.app_state.db.get_invite_code_for_user(user_id)
508 ).await?;
509
510 {
511 let mut pool = this.connection_pool.lock();
512 pool.add_connection(connection_id, user_id, user.admin);
513 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
514
515 if let Some((code, count)) = invite_code {
516 this.peer.send(connection_id, proto::UpdateInviteInfo {
517 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
518 count: count as u32,
519 })?;
520 }
521 }
522
523 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
524 this.peer.send(connection_id, incoming_call)?;
525 }
526
527 let session = Session {
528 user_id,
529 connection_id,
530 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
531 peer: this.peer.clone(),
532 connection_pool: this.connection_pool.clone(),
533 live_kit_client: this.app_state.live_kit_client.clone()
534 };
535 update_user_contacts(user_id, &session).await?;
536
537 let handle_io = handle_io.fuse();
538 futures::pin_mut!(handle_io);
539
540 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
541 // This prevents deadlocks when e.g., client A performs a request to client B and
542 // client B performs a request to client A. If both clients stop processing further
543 // messages until their respective request completes, they won't have a chance to
544 // respond to the other client's request and cause a deadlock.
545 //
546 // This arrangement ensures we will attempt to process earlier messages first, but fall
547 // back to processing messages arrived later in the spirit of making progress.
548 let mut foreground_message_handlers = FuturesUnordered::new();
549 loop {
550 let next_message = incoming_rx.next().fuse();
551 futures::pin_mut!(next_message);
552 futures::select_biased! {
553 _ = teardown.changed().fuse() => return Ok(()),
554 result = handle_io => {
555 if let Err(error) = result {
556 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
557 }
558 break;
559 }
560 _ = foreground_message_handlers.next() => {}
561 message = next_message => {
562 if let Some(message) = message {
563 let type_name = message.payload_type_name();
564 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
565 let span_enter = span.enter();
566 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
567 let is_background = message.is_background();
568 let handle_message = (handler)(message, session.clone());
569 drop(span_enter);
570
571 let handle_message = handle_message.instrument(span);
572 if is_background {
573 executor.spawn_detached(handle_message);
574 } else {
575 foreground_message_handlers.push(handle_message);
576 }
577 } else {
578 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
579 }
580 } else {
581 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
582 break;
583 }
584 }
585 }
586 }
587
588 drop(foreground_message_handlers);
589 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
590 if let Err(error) = sign_out(session, teardown, executor).await {
591 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
592 }
593
594 Ok(())
595 }.instrument(span)
596 }
597
598 pub async fn invite_code_redeemed(
599 self: &Arc<Self>,
600 inviter_id: UserId,
601 invitee_id: UserId,
602 ) -> Result<()> {
603 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
604 if let Some(code) = &user.invite_code {
605 let pool = self.connection_pool.lock();
606 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
607 for connection_id in pool.user_connection_ids(inviter_id) {
608 self.peer.send(
609 connection_id,
610 proto::UpdateContacts {
611 contacts: vec![invitee_contact.clone()],
612 ..Default::default()
613 },
614 )?;
615 self.peer.send(
616 connection_id,
617 proto::UpdateInviteInfo {
618 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
619 count: user.invite_count as u32,
620 },
621 )?;
622 }
623 }
624 }
625 Ok(())
626 }
627
628 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
629 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
630 if let Some(invite_code) = &user.invite_code {
631 let pool = self.connection_pool.lock();
632 for connection_id in pool.user_connection_ids(user_id) {
633 self.peer.send(
634 connection_id,
635 proto::UpdateInviteInfo {
636 url: format!(
637 "{}{}",
638 self.app_state.config.invite_link_prefix, invite_code
639 ),
640 count: user.invite_count as u32,
641 },
642 )?;
643 }
644 }
645 }
646 Ok(())
647 }
648
649 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
650 ServerSnapshot {
651 connection_pool: ConnectionPoolGuard {
652 guard: self.connection_pool.lock(),
653 _not_send: PhantomData,
654 },
655 peer: &self.peer,
656 }
657 }
658}
659
660impl<'a> Deref for ConnectionPoolGuard<'a> {
661 type Target = ConnectionPool;
662
663 fn deref(&self) -> &Self::Target {
664 &*self.guard
665 }
666}
667
668impl<'a> DerefMut for ConnectionPoolGuard<'a> {
669 fn deref_mut(&mut self) -> &mut Self::Target {
670 &mut *self.guard
671 }
672}
673
674impl<'a> Drop for ConnectionPoolGuard<'a> {
675 fn drop(&mut self) {
676 #[cfg(test)]
677 self.check_invariants();
678 }
679}
680
681fn broadcast<F>(
682 sender_id: ConnectionId,
683 receiver_ids: impl IntoIterator<Item = ConnectionId>,
684 mut f: F,
685) where
686 F: FnMut(ConnectionId) -> anyhow::Result<()>,
687{
688 for receiver_id in receiver_ids {
689 if receiver_id != sender_id {
690 f(receiver_id).trace_err();
691 }
692 }
693}
694
695lazy_static! {
696 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
697}
698
699pub struct ProtocolVersion(u32);
700
701impl Header for ProtocolVersion {
702 fn name() -> &'static HeaderName {
703 &ZED_PROTOCOL_VERSION
704 }
705
706 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
707 where
708 Self: Sized,
709 I: Iterator<Item = &'i axum::http::HeaderValue>,
710 {
711 let version = values
712 .next()
713 .ok_or_else(axum::headers::Error::invalid)?
714 .to_str()
715 .map_err(|_| axum::headers::Error::invalid())?
716 .parse()
717 .map_err(|_| axum::headers::Error::invalid())?;
718 Ok(Self(version))
719 }
720
721 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
722 values.extend([self.0.to_string().parse().unwrap()]);
723 }
724}
725
726pub fn routes(server: Arc<Server>) -> Router<Body> {
727 Router::new()
728 .route("/rpc", get(handle_websocket_request))
729 .layer(
730 ServiceBuilder::new()
731 .layer(Extension(server.app_state.clone()))
732 .layer(middleware::from_fn(auth::validate_header)),
733 )
734 .route("/metrics", get(handle_metrics))
735 .layer(Extension(server))
736}
737
738pub async fn handle_websocket_request(
739 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
740 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
741 Extension(server): Extension<Arc<Server>>,
742 Extension(user): Extension<User>,
743 ws: WebSocketUpgrade,
744) -> axum::response::Response {
745 if protocol_version != rpc::PROTOCOL_VERSION {
746 return (
747 StatusCode::UPGRADE_REQUIRED,
748 "client must be upgraded".to_string(),
749 )
750 .into_response();
751 }
752 let socket_address = socket_address.to_string();
753 ws.on_upgrade(move |socket| {
754 use util::ResultExt;
755 let socket = socket
756 .map_ok(to_tungstenite_message)
757 .err_into()
758 .with(|message| async move { Ok(to_axum_message(message)) });
759 let connection = Connection::new(Box::pin(socket));
760 async move {
761 server
762 .handle_connection(connection, socket_address, user, None, Executor::Production)
763 .await
764 .log_err();
765 }
766 })
767}
768
769pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
770 let connections = server
771 .connection_pool
772 .lock()
773 .connections()
774 .filter(|connection| !connection.admin)
775 .count();
776
777 METRIC_CONNECTIONS.set(connections as _);
778
779 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
780 METRIC_SHARED_PROJECTS.set(shared_projects as _);
781
782 let encoder = prometheus::TextEncoder::new();
783 let metric_families = prometheus::gather();
784 let encoded_metrics = encoder
785 .encode_to_string(&metric_families)
786 .map_err(|err| anyhow!("{}", err))?;
787 Ok(encoded_metrics)
788}
789
790#[instrument(err, skip(executor))]
791async fn sign_out(
792 session: Session,
793 mut teardown: watch::Receiver<()>,
794 executor: Executor,
795) -> Result<()> {
796 session.peer.disconnect(session.connection_id);
797 session
798 .connection_pool()
799 .await
800 .remove_connection(session.connection_id)?;
801
802 session
803 .db()
804 .await
805 .connection_lost(session.connection_id)
806 .await
807 .trace_err();
808
809 futures::select_biased! {
810 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
811 leave_room_for_session(&session).await.trace_err();
812
813 if !session
814 .connection_pool()
815 .await
816 .is_user_online(session.user_id)
817 {
818 let db = session.db().await;
819 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
820 room_updated(&room, &session.peer);
821 }
822 }
823 update_user_contacts(session.user_id, &session).await?;
824 }
825 _ = teardown.changed().fuse() => {}
826 }
827
828 Ok(())
829}
830
831async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
832 response.send(proto::Ack {})?;
833 Ok(())
834}
835
836async fn create_room(
837 _request: proto::CreateRoom,
838 response: Response<proto::CreateRoom>,
839 session: Session,
840) -> Result<()> {
841 let live_kit_room = nanoid::nanoid!(30);
842 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
843 if let Some(_) = live_kit
844 .create_room(live_kit_room.clone())
845 .await
846 .trace_err()
847 {
848 if let Some(token) = live_kit
849 .room_token(&live_kit_room, &session.user_id.to_string())
850 .trace_err()
851 {
852 Some(proto::LiveKitConnectionInfo {
853 server_url: live_kit.url().into(),
854 token,
855 })
856 } else {
857 None
858 }
859 } else {
860 None
861 }
862 } else {
863 None
864 };
865
866 {
867 let room = session
868 .db()
869 .await
870 .create_room(session.user_id, session.connection_id, &live_kit_room)
871 .await?;
872
873 response.send(proto::CreateRoomResponse {
874 room: Some(room.clone()),
875 live_kit_connection_info,
876 })?;
877 }
878
879 update_user_contacts(session.user_id, &session).await?;
880 Ok(())
881}
882
883async fn join_room(
884 request: proto::JoinRoom,
885 response: Response<proto::JoinRoom>,
886 session: Session,
887) -> Result<()> {
888 let room_id = RoomId::from_proto(request.id);
889 let room = {
890 let room = session
891 .db()
892 .await
893 .join_room(room_id, session.user_id, session.connection_id)
894 .await?;
895 room_updated(&room, &session.peer);
896 room.clone()
897 };
898
899 for connection_id in session
900 .connection_pool()
901 .await
902 .user_connection_ids(session.user_id)
903 {
904 session
905 .peer
906 .send(
907 connection_id,
908 proto::CallCanceled {
909 room_id: room_id.to_proto(),
910 },
911 )
912 .trace_err();
913 }
914
915 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
916 if let Some(token) = live_kit
917 .room_token(&room.live_kit_room, &session.user_id.to_string())
918 .trace_err()
919 {
920 Some(proto::LiveKitConnectionInfo {
921 server_url: live_kit.url().into(),
922 token,
923 })
924 } else {
925 None
926 }
927 } else {
928 None
929 };
930
931 response.send(proto::JoinRoomResponse {
932 room: Some(room),
933 live_kit_connection_info,
934 })?;
935
936 update_user_contacts(session.user_id, &session).await?;
937 Ok(())
938}
939
/// Handles a client rejoining its room under a new connection id.
///
/// Replies with the room, the projects it hosts (reshared), and the projects
/// it had joined (rejoined). Collaborators are told about the peer-id change.
/// Rejoined projects' worktree entries, diagnostics, and language-server
/// status are re-streamed to the client. The caller's contacts are updated
/// last.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Everything that uses the database result happens inside this scope. The
    // result is mutated through `&mut` field access below; it is presumably a
    // guard type — TODO confirm it must be dropped before
    // `update_user_contacts`.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Answer the request first, with metadata for every reshared/rejoined project.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator the host's peer id changed,
        // then forward the current worktree list to everyone but the host.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                session.connection_id,
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: tell each collaborator this guest's peer id changed.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream rejoined projects' state to the rejoining client. The
        // worktrees are moved out of the result, since they are consumed here.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in tests so splitting is exercised.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.is_complete,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }
            }

            // Report each language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1097
1098async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
1099 leave_room_for_session(&session).await
1100}
1101
1102async fn call(
1103 request: proto::Call,
1104 response: Response<proto::Call>,
1105 session: Session,
1106) -> Result<()> {
1107 let room_id = RoomId::from_proto(request.room_id);
1108 let calling_user_id = session.user_id;
1109 let calling_connection_id = session.connection_id;
1110 let called_user_id = UserId::from_proto(request.called_user_id);
1111 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1112 if !session
1113 .db()
1114 .await
1115 .has_contact(calling_user_id, called_user_id)
1116 .await?
1117 {
1118 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1119 }
1120
1121 let incoming_call = {
1122 let (room, incoming_call) = &mut *session
1123 .db()
1124 .await
1125 .call(
1126 room_id,
1127 calling_user_id,
1128 calling_connection_id,
1129 called_user_id,
1130 initial_project_id,
1131 )
1132 .await?;
1133 room_updated(&room, &session.peer);
1134 mem::take(incoming_call)
1135 };
1136 update_user_contacts(called_user_id, &session).await?;
1137
1138 let mut calls = session
1139 .connection_pool()
1140 .await
1141 .user_connection_ids(called_user_id)
1142 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1143 .collect::<FuturesUnordered<_>>();
1144
1145 while let Some(call_response) = calls.next().await {
1146 match call_response.as_ref() {
1147 Ok(_) => {
1148 response.send(proto::Ack {})?;
1149 return Ok(());
1150 }
1151 Err(_) => {
1152 call_response.trace_err();
1153 }
1154 }
1155 }
1156
1157 {
1158 let room = session
1159 .db()
1160 .await
1161 .call_failed(room_id, called_user_id)
1162 .await?;
1163 room_updated(&room, &session.peer);
1164 }
1165 update_user_contacts(called_user_id, &session).await?;
1166
1167 Err(anyhow!("failed to ring user"))?
1168}
1169
1170async fn cancel_call(
1171 request: proto::CancelCall,
1172 response: Response<proto::CancelCall>,
1173 session: Session,
1174) -> Result<()> {
1175 let called_user_id = UserId::from_proto(request.called_user_id);
1176 let room_id = RoomId::from_proto(request.room_id);
1177 {
1178 let room = session
1179 .db()
1180 .await
1181 .cancel_call(room_id, session.connection_id, called_user_id)
1182 .await?;
1183 room_updated(&room, &session.peer);
1184 }
1185
1186 for connection_id in session
1187 .connection_pool()
1188 .await
1189 .user_connection_ids(called_user_id)
1190 {
1191 session
1192 .peer
1193 .send(
1194 connection_id,
1195 proto::CallCanceled {
1196 room_id: room_id.to_proto(),
1197 },
1198 )
1199 .trace_err();
1200 }
1201 response.send(proto::Ack {})?;
1202
1203 update_user_contacts(called_user_id, &session).await?;
1204 Ok(())
1205}
1206
1207async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1208 let room_id = RoomId::from_proto(message.room_id);
1209 {
1210 let room = session
1211 .db()
1212 .await
1213 .decline_call(Some(room_id), session.user_id)
1214 .await?
1215 .ok_or_else(|| anyhow!("failed to decline call"))?;
1216 room_updated(&room, &session.peer);
1217 }
1218
1219 for connection_id in session
1220 .connection_pool()
1221 .await
1222 .user_connection_ids(session.user_id)
1223 {
1224 session
1225 .peer
1226 .send(
1227 connection_id,
1228 proto::CallCanceled {
1229 room_id: room_id.to_proto(),
1230 },
1231 )
1232 .trace_err();
1233 }
1234 update_user_contacts(session.user_id, &session).await?;
1235 Ok(())
1236}
1237
1238async fn update_participant_location(
1239 request: proto::UpdateParticipantLocation,
1240 response: Response<proto::UpdateParticipantLocation>,
1241 session: Session,
1242) -> Result<()> {
1243 let room_id = RoomId::from_proto(request.room_id);
1244 let location = request
1245 .location
1246 .ok_or_else(|| anyhow!("invalid location"))?;
1247 let room = session
1248 .db()
1249 .await
1250 .update_room_participant_location(room_id, session.connection_id, location)
1251 .await?;
1252 room_updated(&room, &session.peer);
1253 response.send(proto::Ack {})?;
1254 Ok(())
1255}
1256
1257async fn share_project(
1258 request: proto::ShareProject,
1259 response: Response<proto::ShareProject>,
1260 session: Session,
1261) -> Result<()> {
1262 let (project_id, room) = &*session
1263 .db()
1264 .await
1265 .share_project(
1266 RoomId::from_proto(request.room_id),
1267 session.connection_id,
1268 &request.worktrees,
1269 )
1270 .await?;
1271 response.send(proto::ShareProjectResponse {
1272 project_id: project_id.to_proto(),
1273 })?;
1274 room_updated(&room, &session.peer);
1275
1276 Ok(())
1277}
1278
1279async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1280 let project_id = ProjectId::from_proto(message.project_id);
1281
1282 let (room, guest_connection_ids) = &*session
1283 .db()
1284 .await
1285 .unshare_project(project_id, session.connection_id)
1286 .await?;
1287
1288 broadcast(
1289 session.connection_id,
1290 guest_connection_ids.iter().copied(),
1291 |conn_id| session.peer.send(conn_id, message.clone()),
1292 );
1293 room_updated(&room, &session.peer);
1294
1295 Ok(())
1296}
1297
1298async fn join_project(
1299 request: proto::JoinProject,
1300 response: Response<proto::JoinProject>,
1301 session: Session,
1302) -> Result<()> {
1303 let project_id = ProjectId::from_proto(request.project_id);
1304 let guest_user_id = session.user_id;
1305
1306 tracing::info!(%project_id, "join project");
1307
1308 let (project, replica_id) = &mut *session
1309 .db()
1310 .await
1311 .join_project(project_id, session.connection_id)
1312 .await?;
1313
1314 let collaborators = project
1315 .collaborators
1316 .iter()
1317 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1318 .map(|collaborator| collaborator.to_proto())
1319 .collect::<Vec<_>>();
1320 let worktrees = project
1321 .worktrees
1322 .iter()
1323 .map(|(id, worktree)| proto::WorktreeMetadata {
1324 id: *id,
1325 root_name: worktree.root_name.clone(),
1326 visible: worktree.visible,
1327 abs_path: worktree.abs_path.clone(),
1328 })
1329 .collect::<Vec<_>>();
1330
1331 for collaborator in &collaborators {
1332 session
1333 .peer
1334 .send(
1335 collaborator.peer_id.unwrap().into(),
1336 proto::AddProjectCollaborator {
1337 project_id: project_id.to_proto(),
1338 collaborator: Some(proto::Collaborator {
1339 peer_id: Some(session.connection_id.into()),
1340 replica_id: replica_id.0 as u32,
1341 user_id: guest_user_id.to_proto(),
1342 }),
1343 },
1344 )
1345 .trace_err();
1346 }
1347
1348 // First, we send the metadata associated with each worktree.
1349 response.send(proto::JoinProjectResponse {
1350 worktrees: worktrees.clone(),
1351 replica_id: replica_id.0 as u32,
1352 collaborators: collaborators.clone(),
1353 language_servers: project.language_servers.clone(),
1354 })?;
1355
1356 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1357 #[cfg(any(test, feature = "test-support"))]
1358 const MAX_CHUNK_SIZE: usize = 2;
1359 #[cfg(not(any(test, feature = "test-support")))]
1360 const MAX_CHUNK_SIZE: usize = 256;
1361
1362 // Stream this worktree's entries.
1363 let message = proto::UpdateWorktree {
1364 project_id: project_id.to_proto(),
1365 worktree_id,
1366 abs_path: worktree.abs_path.clone(),
1367 root_name: worktree.root_name,
1368 updated_entries: worktree.entries,
1369 removed_entries: Default::default(),
1370 scan_id: worktree.scan_id,
1371 is_last_update: worktree.is_complete,
1372 };
1373 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1374 session.peer.send(session.connection_id, update.clone())?;
1375 }
1376
1377 // Stream this worktree's diagnostics.
1378 for summary in worktree.diagnostic_summaries {
1379 session.peer.send(
1380 session.connection_id,
1381 proto::UpdateDiagnosticSummary {
1382 project_id: project_id.to_proto(),
1383 worktree_id: worktree.id,
1384 summary: Some(summary),
1385 },
1386 )?;
1387 }
1388 }
1389
1390 for language_server in &project.language_servers {
1391 session.peer.send(
1392 session.connection_id,
1393 proto::UpdateLanguageServer {
1394 project_id: project_id.to_proto(),
1395 language_server_id: language_server.id,
1396 variant: Some(
1397 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1398 proto::LspDiskBasedDiagnosticsUpdated {},
1399 ),
1400 ),
1401 },
1402 )?;
1403 }
1404
1405 Ok(())
1406}
1407
1408async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1409 let sender_id = session.connection_id;
1410 let project_id = ProjectId::from_proto(request.project_id);
1411
1412 let project = session
1413 .db()
1414 .await
1415 .leave_project(project_id, sender_id)
1416 .await?;
1417 tracing::info!(
1418 %project_id,
1419 host_user_id = %project.host_user_id,
1420 host_connection_id = %project.host_connection_id,
1421 "leave project"
1422 );
1423 project_left(&project, &session);
1424
1425 Ok(())
1426}
1427
1428async fn update_project(
1429 request: proto::UpdateProject,
1430 response: Response<proto::UpdateProject>,
1431 session: Session,
1432) -> Result<()> {
1433 let project_id = ProjectId::from_proto(request.project_id);
1434 let (room, guest_connection_ids) = &*session
1435 .db()
1436 .await
1437 .update_project(project_id, session.connection_id, &request.worktrees)
1438 .await?;
1439 broadcast(
1440 session.connection_id,
1441 guest_connection_ids.iter().copied(),
1442 |connection_id| {
1443 session
1444 .peer
1445 .forward_send(session.connection_id, connection_id, request.clone())
1446 },
1447 );
1448 room_updated(&room, &session.peer);
1449 response.send(proto::Ack {})?;
1450
1451 Ok(())
1452}
1453
1454async fn update_worktree(
1455 request: proto::UpdateWorktree,
1456 response: Response<proto::UpdateWorktree>,
1457 session: Session,
1458) -> Result<()> {
1459 let guest_connection_ids = session
1460 .db()
1461 .await
1462 .update_worktree(&request, session.connection_id)
1463 .await?;
1464
1465 broadcast(
1466 session.connection_id,
1467 guest_connection_ids.iter().copied(),
1468 |connection_id| {
1469 session
1470 .peer
1471 .forward_send(session.connection_id, connection_id, request.clone())
1472 },
1473 );
1474 response.send(proto::Ack {})?;
1475 Ok(())
1476}
1477
1478async fn update_diagnostic_summary(
1479 message: proto::UpdateDiagnosticSummary,
1480 session: Session,
1481) -> Result<()> {
1482 let guest_connection_ids = session
1483 .db()
1484 .await
1485 .update_diagnostic_summary(&message, session.connection_id)
1486 .await?;
1487
1488 broadcast(
1489 session.connection_id,
1490 guest_connection_ids.iter().copied(),
1491 |connection_id| {
1492 session
1493 .peer
1494 .forward_send(session.connection_id, connection_id, message.clone())
1495 },
1496 );
1497
1498 Ok(())
1499}
1500
1501async fn start_language_server(
1502 request: proto::StartLanguageServer,
1503 session: Session,
1504) -> Result<()> {
1505 let guest_connection_ids = session
1506 .db()
1507 .await
1508 .start_language_server(&request, session.connection_id)
1509 .await?;
1510
1511 broadcast(
1512 session.connection_id,
1513 guest_connection_ids.iter().copied(),
1514 |connection_id| {
1515 session
1516 .peer
1517 .forward_send(session.connection_id, connection_id, request.clone())
1518 },
1519 );
1520 Ok(())
1521}
1522
1523async fn update_language_server(
1524 request: proto::UpdateLanguageServer,
1525 session: Session,
1526) -> Result<()> {
1527 let project_id = ProjectId::from_proto(request.project_id);
1528 let project_connection_ids = session
1529 .db()
1530 .await
1531 .project_connection_ids(project_id, session.connection_id)
1532 .await?;
1533 broadcast(
1534 session.connection_id,
1535 project_connection_ids.iter().copied(),
1536 |connection_id| {
1537 session
1538 .peer
1539 .forward_send(session.connection_id, connection_id, request.clone())
1540 },
1541 );
1542 Ok(())
1543}
1544
1545async fn forward_project_request<T>(
1546 request: T,
1547 response: Response<T>,
1548 session: Session,
1549) -> Result<()>
1550where
1551 T: EntityMessage + RequestMessage,
1552{
1553 let project_id = ProjectId::from_proto(request.remote_entity_id());
1554 let host_connection_id = {
1555 let collaborators = session
1556 .db()
1557 .await
1558 .project_collaborators(project_id, session.connection_id)
1559 .await?;
1560 collaborators
1561 .iter()
1562 .find(|collaborator| collaborator.is_host)
1563 .ok_or_else(|| anyhow!("host not found"))?
1564 .connection_id
1565 };
1566
1567 let payload = session
1568 .peer
1569 .forward_request(session.connection_id, host_connection_id, request)
1570 .await?;
1571
1572 response.send(payload)?;
1573 Ok(())
1574}
1575
1576async fn save_buffer(
1577 request: proto::SaveBuffer,
1578 response: Response<proto::SaveBuffer>,
1579 session: Session,
1580) -> Result<()> {
1581 let project_id = ProjectId::from_proto(request.project_id);
1582 let host_connection_id = {
1583 let collaborators = session
1584 .db()
1585 .await
1586 .project_collaborators(project_id, session.connection_id)
1587 .await?;
1588 collaborators
1589 .iter()
1590 .find(|collaborator| collaborator.is_host)
1591 .ok_or_else(|| anyhow!("host not found"))?
1592 .connection_id
1593 };
1594 let response_payload = session
1595 .peer
1596 .forward_request(session.connection_id, host_connection_id, request.clone())
1597 .await?;
1598
1599 let mut collaborators = session
1600 .db()
1601 .await
1602 .project_collaborators(project_id, session.connection_id)
1603 .await?;
1604 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1605 let project_connection_ids = collaborators
1606 .iter()
1607 .map(|collaborator| collaborator.connection_id);
1608 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1609 session
1610 .peer
1611 .forward_send(host_connection_id, conn_id, response_payload.clone())
1612 });
1613 response.send(response_payload)?;
1614 Ok(())
1615}
1616
1617async fn create_buffer_for_peer(
1618 request: proto::CreateBufferForPeer,
1619 session: Session,
1620) -> Result<()> {
1621 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1622 session
1623 .peer
1624 .forward_send(session.connection_id, peer_id.into(), request)?;
1625 Ok(())
1626}
1627
1628async fn update_buffer(
1629 request: proto::UpdateBuffer,
1630 response: Response<proto::UpdateBuffer>,
1631 session: Session,
1632) -> Result<()> {
1633 let project_id = ProjectId::from_proto(request.project_id);
1634 let project_connection_ids = session
1635 .db()
1636 .await
1637 .project_connection_ids(project_id, session.connection_id)
1638 .await?;
1639
1640 broadcast(
1641 session.connection_id,
1642 project_connection_ids.iter().copied(),
1643 |connection_id| {
1644 session
1645 .peer
1646 .forward_send(session.connection_id, connection_id, request.clone())
1647 },
1648 );
1649 response.send(proto::Ack {})?;
1650 Ok(())
1651}
1652
1653async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1654 let project_id = ProjectId::from_proto(request.project_id);
1655 let project_connection_ids = session
1656 .db()
1657 .await
1658 .project_connection_ids(project_id, session.connection_id)
1659 .await?;
1660
1661 broadcast(
1662 session.connection_id,
1663 project_connection_ids.iter().copied(),
1664 |connection_id| {
1665 session
1666 .peer
1667 .forward_send(session.connection_id, connection_id, request.clone())
1668 },
1669 );
1670 Ok(())
1671}
1672
1673async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1674 let project_id = ProjectId::from_proto(request.project_id);
1675 let project_connection_ids = session
1676 .db()
1677 .await
1678 .project_connection_ids(project_id, session.connection_id)
1679 .await?;
1680 broadcast(
1681 session.connection_id,
1682 project_connection_ids.iter().copied(),
1683 |connection_id| {
1684 session
1685 .peer
1686 .forward_send(session.connection_id, connection_id, request.clone())
1687 },
1688 );
1689 Ok(())
1690}
1691
1692async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1693 let project_id = ProjectId::from_proto(request.project_id);
1694 let project_connection_ids = session
1695 .db()
1696 .await
1697 .project_connection_ids(project_id, session.connection_id)
1698 .await?;
1699 broadcast(
1700 session.connection_id,
1701 project_connection_ids.iter().copied(),
1702 |connection_id| {
1703 session
1704 .peer
1705 .forward_send(session.connection_id, connection_id, request.clone())
1706 },
1707 );
1708 Ok(())
1709}
1710
1711async fn follow(
1712 request: proto::Follow,
1713 response: Response<proto::Follow>,
1714 session: Session,
1715) -> Result<()> {
1716 let project_id = ProjectId::from_proto(request.project_id);
1717 let leader_id = request
1718 .leader_id
1719 .ok_or_else(|| anyhow!("invalid leader id"))?
1720 .into();
1721 let follower_id = session.connection_id;
1722 {
1723 let project_connection_ids = session
1724 .db()
1725 .await
1726 .project_connection_ids(project_id, session.connection_id)
1727 .await?;
1728
1729 if !project_connection_ids.contains(&leader_id) {
1730 Err(anyhow!("no such peer"))?;
1731 }
1732 }
1733
1734 let mut response_payload = session
1735 .peer
1736 .forward_request(session.connection_id, leader_id, request)
1737 .await?;
1738 response_payload
1739 .views
1740 .retain(|view| view.leader_id != Some(follower_id.into()));
1741 response.send(response_payload)?;
1742 Ok(())
1743}
1744
1745async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1746 let project_id = ProjectId::from_proto(request.project_id);
1747 let leader_id = request
1748 .leader_id
1749 .ok_or_else(|| anyhow!("invalid leader id"))?
1750 .into();
1751 let project_connection_ids = session
1752 .db()
1753 .await
1754 .project_connection_ids(project_id, session.connection_id)
1755 .await?;
1756 if !project_connection_ids.contains(&leader_id) {
1757 Err(anyhow!("no such peer"))?;
1758 }
1759 session
1760 .peer
1761 .forward_send(session.connection_id, leader_id, request)?;
1762 Ok(())
1763}
1764
1765async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1766 let project_id = ProjectId::from_proto(request.project_id);
1767 let project_connection_ids = session
1768 .db
1769 .lock()
1770 .await
1771 .project_connection_ids(project_id, session.connection_id)
1772 .await?;
1773
1774 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1775 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1776 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1777 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1778 });
1779 for follower_peer_id in request.follower_ids.iter().copied() {
1780 let follower_connection_id = follower_peer_id.into();
1781 if project_connection_ids.contains(&follower_connection_id)
1782 && Some(follower_peer_id) != leader_id
1783 {
1784 session.peer.forward_send(
1785 session.connection_id,
1786 follower_connection_id,
1787 request.clone(),
1788 )?;
1789 }
1790 }
1791 Ok(())
1792}
1793
1794async fn get_users(
1795 request: proto::GetUsers,
1796 response: Response<proto::GetUsers>,
1797 session: Session,
1798) -> Result<()> {
1799 let user_ids = request
1800 .user_ids
1801 .into_iter()
1802 .map(UserId::from_proto)
1803 .collect();
1804 let users = session
1805 .db()
1806 .await
1807 .get_users_by_ids(user_ids)
1808 .await?
1809 .into_iter()
1810 .map(|user| proto::User {
1811 id: user.id.to_proto(),
1812 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1813 github_login: user.github_login,
1814 })
1815 .collect();
1816 response.send(proto::UsersResponse { users })?;
1817 Ok(())
1818}
1819
1820async fn fuzzy_search_users(
1821 request: proto::FuzzySearchUsers,
1822 response: Response<proto::FuzzySearchUsers>,
1823 session: Session,
1824) -> Result<()> {
1825 let query = request.query;
1826 let users = match query.len() {
1827 0 => vec![],
1828 1 | 2 => session
1829 .db()
1830 .await
1831 .get_user_by_github_account(&query, None)
1832 .await?
1833 .into_iter()
1834 .collect(),
1835 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1836 };
1837 let users = users
1838 .into_iter()
1839 .filter(|user| user.id != session.user_id)
1840 .map(|user| proto::User {
1841 id: user.id.to_proto(),
1842 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1843 github_login: user.github_login,
1844 })
1845 .collect();
1846 response.send(proto::UsersResponse { users })?;
1847 Ok(())
1848}
1849
1850async fn request_contact(
1851 request: proto::RequestContact,
1852 response: Response<proto::RequestContact>,
1853 session: Session,
1854) -> Result<()> {
1855 let requester_id = session.user_id;
1856 let responder_id = UserId::from_proto(request.responder_id);
1857 if requester_id == responder_id {
1858 return Err(anyhow!("cannot add yourself as a contact"))?;
1859 }
1860
1861 session
1862 .db()
1863 .await
1864 .send_contact_request(requester_id, responder_id)
1865 .await?;
1866
1867 // Update outgoing contact requests of requester
1868 let mut update = proto::UpdateContacts::default();
1869 update.outgoing_requests.push(responder_id.to_proto());
1870 for connection_id in session
1871 .connection_pool()
1872 .await
1873 .user_connection_ids(requester_id)
1874 {
1875 session.peer.send(connection_id, update.clone())?;
1876 }
1877
1878 // Update incoming contact requests of responder
1879 let mut update = proto::UpdateContacts::default();
1880 update
1881 .incoming_requests
1882 .push(proto::IncomingContactRequest {
1883 requester_id: requester_id.to_proto(),
1884 should_notify: true,
1885 });
1886 for connection_id in session
1887 .connection_pool()
1888 .await
1889 .user_connection_ids(responder_id)
1890 {
1891 session.peer.send(connection_id, update.clone())?;
1892 }
1893
1894 response.send(proto::Ack {})?;
1895 Ok(())
1896}
1897
1898async fn respond_to_contact_request(
1899 request: proto::RespondToContactRequest,
1900 response: Response<proto::RespondToContactRequest>,
1901 session: Session,
1902) -> Result<()> {
1903 let responder_id = session.user_id;
1904 let requester_id = UserId::from_proto(request.requester_id);
1905 let db = session.db().await;
1906 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1907 db.dismiss_contact_notification(responder_id, requester_id)
1908 .await?;
1909 } else {
1910 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1911
1912 db.respond_to_contact_request(responder_id, requester_id, accept)
1913 .await?;
1914 let requester_busy = db.is_user_busy(requester_id).await?;
1915 let responder_busy = db.is_user_busy(responder_id).await?;
1916
1917 let pool = session.connection_pool().await;
1918 // Update responder with new contact
1919 let mut update = proto::UpdateContacts::default();
1920 if accept {
1921 update
1922 .contacts
1923 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1924 }
1925 update
1926 .remove_incoming_requests
1927 .push(requester_id.to_proto());
1928 for connection_id in pool.user_connection_ids(responder_id) {
1929 session.peer.send(connection_id, update.clone())?;
1930 }
1931
1932 // Update requester with new contact
1933 let mut update = proto::UpdateContacts::default();
1934 if accept {
1935 update
1936 .contacts
1937 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1938 }
1939 update
1940 .remove_outgoing_requests
1941 .push(responder_id.to_proto());
1942 for connection_id in pool.user_connection_ids(requester_id) {
1943 session.peer.send(connection_id, update.clone())?;
1944 }
1945 }
1946
1947 response.send(proto::Ack {})?;
1948 Ok(())
1949}
1950
1951async fn remove_contact(
1952 request: proto::RemoveContact,
1953 response: Response<proto::RemoveContact>,
1954 session: Session,
1955) -> Result<()> {
1956 let requester_id = session.user_id;
1957 let responder_id = UserId::from_proto(request.user_id);
1958 let db = session.db().await;
1959 db.remove_contact(requester_id, responder_id).await?;
1960
1961 let pool = session.connection_pool().await;
1962 // Update outgoing contact requests of requester
1963 let mut update = proto::UpdateContacts::default();
1964 update
1965 .remove_outgoing_requests
1966 .push(responder_id.to_proto());
1967 for connection_id in pool.user_connection_ids(requester_id) {
1968 session.peer.send(connection_id, update.clone())?;
1969 }
1970
1971 // Update incoming contact requests of responder
1972 let mut update = proto::UpdateContacts::default();
1973 update
1974 .remove_incoming_requests
1975 .push(requester_id.to_proto());
1976 for connection_id in pool.user_connection_ids(responder_id) {
1977 session.peer.send(connection_id, update.clone())?;
1978 }
1979
1980 response.send(proto::Ack {})?;
1981 Ok(())
1982}
1983
1984async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1985 let project_id = ProjectId::from_proto(request.project_id);
1986 let project_connection_ids = session
1987 .db()
1988 .await
1989 .project_connection_ids(project_id, session.connection_id)
1990 .await?;
1991 broadcast(
1992 session.connection_id,
1993 project_connection_ids.iter().copied(),
1994 |connection_id| {
1995 session
1996 .peer
1997 .forward_send(session.connection_id, connection_id, request.clone())
1998 },
1999 );
2000 Ok(())
2001}
2002
2003async fn get_private_user_info(
2004 _request: proto::GetPrivateUserInfo,
2005 response: Response<proto::GetPrivateUserInfo>,
2006 session: Session,
2007) -> Result<()> {
2008 let metrics_id = session
2009 .db()
2010 .await
2011 .get_user_metrics_id(session.user_id)
2012 .await?;
2013 let user = session
2014 .db()
2015 .await
2016 .get_user_by_id(session.user_id)
2017 .await?
2018 .ok_or_else(|| anyhow!("user not found"))?;
2019 response.send(proto::GetPrivateUserInfoResponse {
2020 metrics_id,
2021 staff: user.admin,
2022 })?;
2023 Ok(())
2024}
2025
2026fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2027 match message {
2028 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2029 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2030 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2031 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2032 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2033 code: frame.code.into(),
2034 reason: frame.reason,
2035 })),
2036 }
2037}
2038
2039fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2040 match message {
2041 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2042 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2043 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2044 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2045 AxumMessage::Close(frame) => {
2046 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2047 code: frame.code.into(),
2048 reason: frame.reason,
2049 }))
2050 }
2051 }
2052}
2053
2054fn build_initial_contacts_update(
2055 contacts: Vec<db::Contact>,
2056 pool: &ConnectionPool,
2057) -> proto::UpdateContacts {
2058 let mut update = proto::UpdateContacts::default();
2059
2060 for contact in contacts {
2061 match contact {
2062 db::Contact::Accepted {
2063 user_id,
2064 should_notify,
2065 busy,
2066 } => {
2067 update
2068 .contacts
2069 .push(contact_for_user(user_id, should_notify, busy, &pool));
2070 }
2071 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2072 db::Contact::Incoming {
2073 user_id,
2074 should_notify,
2075 } => update
2076 .incoming_requests
2077 .push(proto::IncomingContactRequest {
2078 requester_id: user_id.to_proto(),
2079 should_notify,
2080 }),
2081 }
2082 }
2083
2084 update
2085}
2086
2087fn contact_for_user(
2088 user_id: UserId,
2089 should_notify: bool,
2090 busy: bool,
2091 pool: &ConnectionPool,
2092) -> proto::Contact {
2093 proto::Contact {
2094 user_id: user_id.to_proto(),
2095 online: pool.is_user_online(user_id),
2096 busy,
2097 should_notify,
2098 }
2099}
2100
2101fn room_updated(room: &proto::Room, peer: &Peer) {
2102 for participant in &room.participants {
2103 if let Some(peer_id) = participant
2104 .peer_id
2105 .ok_or_else(|| anyhow!("invalid participant peer id"))
2106 .trace_err()
2107 {
2108 peer.send(
2109 peer_id.into(),
2110 proto::RoomUpdated {
2111 room: Some(room.clone()),
2112 },
2113 )
2114 .trace_err();
2115 }
2116 }
2117}
2118
2119async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2120 let db = session.db().await;
2121 let contacts = db.get_contacts(user_id).await?;
2122 let busy = db.is_user_busy(user_id).await?;
2123
2124 let pool = session.connection_pool().await;
2125 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2126 for contact in contacts {
2127 if let db::Contact::Accepted {
2128 user_id: contact_user_id,
2129 ..
2130 } = contact
2131 {
2132 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2133 session
2134 .peer
2135 .send(
2136 contact_conn_id,
2137 proto::UpdateContacts {
2138 contacts: vec![updated_contact.clone()],
2139 remove_contacts: Default::default(),
2140 incoming_requests: Default::default(),
2141 remove_incoming_requests: Default::default(),
2142 outgoing_requests: Default::default(),
2143 remove_outgoing_requests: Default::default(),
2144 },
2145 )
2146 .trace_err();
2147 }
2148 }
2149 }
2150 Ok(())
2151}
2152
2153async fn leave_room_for_session(session: &Session) -> Result<()> {
2154 let mut contacts_to_update = HashSet::default();
2155
2156 let room_id;
2157 let canceled_calls_to_user_ids;
2158 let live_kit_room;
2159 let delete_live_kit_room;
2160 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2161 contacts_to_update.insert(session.user_id);
2162
2163 for project in left_room.left_projects.values() {
2164 project_left(project, session);
2165 }
2166
2167 room_updated(&left_room.room, &session.peer);
2168 room_id = RoomId::from_proto(left_room.room.id);
2169 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2170 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2171 delete_live_kit_room = left_room.room.participants.is_empty();
2172 } else {
2173 return Ok(());
2174 }
2175
2176 {
2177 let pool = session.connection_pool().await;
2178 for canceled_user_id in canceled_calls_to_user_ids {
2179 for connection_id in pool.user_connection_ids(canceled_user_id) {
2180 session
2181 .peer
2182 .send(
2183 connection_id,
2184 proto::CallCanceled {
2185 room_id: room_id.to_proto(),
2186 },
2187 )
2188 .trace_err();
2189 }
2190 contacts_to_update.insert(canceled_user_id);
2191 }
2192 }
2193
2194 for contact_user_id in contacts_to_update {
2195 update_user_contacts(contact_user_id, &session).await?;
2196 }
2197
2198 if let Some(live_kit) = session.live_kit_client.as_ref() {
2199 live_kit
2200 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2201 .await
2202 .trace_err();
2203
2204 if delete_live_kit_room {
2205 live_kit.delete_room(live_kit_room).await.trace_err();
2206 }
2207 }
2208
2209 Ok(())
2210}
2211
2212fn project_left(project: &db::LeftProject, session: &Session) {
2213 for connection_id in &project.connection_ids {
2214 if project.host_user_id == session.user_id {
2215 session
2216 .peer
2217 .send(
2218 *connection_id,
2219 proto::UnshareProject {
2220 project_id: project.id.to_proto(),
2221 },
2222 )
2223 .trace_err();
2224 } else {
2225 session
2226 .peer
2227 .send(
2228 *connection_id,
2229 proto::RemoveProjectCollaborator {
2230 project_id: project.id.to_proto(),
2231 peer_id: Some(session.connection_id.into()),
2232 },
2233 )
2234 .trace_err();
2235 }
2236 }
2237
2238 session
2239 .peer
2240 .send(
2241 session.connection_id,
2242 proto::UnshareProject {
2243 project_id: project.id.to_proto(),
2244 },
2245 )
2246 .trace_err();
2247}
2248
2249pub trait ResultExt {
2250 type Ok;
2251
2252 fn trace_err(self) -> Option<Self::Ok>;
2253}
2254
2255impl<T, E> ResultExt for Result<T, E>
2256where
2257 E: std::fmt::Debug,
2258{
2259 type Ok = T;
2260
2261 fn trace_err(self) -> Option<T> {
2262 match self {
2263 Ok(value) => Some(value),
2264 Err(error) => {
2265 tracing::error!("{:?}", error);
2266 None
2267 }
2268 }
2269 }
2270}