1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection drops before making the user leave their
/// room and refreshing their contacts, giving the client a chance to reconnect first.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before cleaning up stale rooms and stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
// Prometheus gauges exported by `handle_metrics`. Each one is registered in the default
// registry on first access; `unwrap` panics at that point if registration fails (for example
// because a metric with the same name is already registered).
lazy_static! {
    /// Number of non-admin connections in the connection pool; refreshed by `handle_metrics`.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Value of `Database::project_count_excluding_admins`; refreshed by `handle_metrics`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The collab RPC server.
///
/// Owns the `Peer` used to talk to clients, the shared pool of live connections, and the
/// table of message handlers that `Server::new` populates.
pub struct Server {
    /// This server's id. It sits behind a mutex because the test-only `reset` reassigns it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sent on by `teardown`. Connection loops and in-flight sign-outs subscribe to it.
    teardown: watch::Sender<()>,
}
149
/// Lock guard over the shared `ConnectionPool`.
///
/// `PhantomData<Rc<()>>` makes this type `!Send` (because `Rc` is `!Send`). As a result, a
/// future that must be `Send` cannot hold the guard across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
154
/// Serializable view of the server's peer and connection pool.
///
/// The connection-pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
161
162pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
163where
164 S: Serializer,
165 T: Deref<Target = U>,
166 U: Serialize,
167{
168 Serialize::serialize(value.deref(), serializer)
169}
170
171impl Server {
172 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
173 let mut server = Self {
174 id: parking_lot::Mutex::new(id),
175 peer: Peer::new(id.0 as u32),
176 app_state,
177 executor,
178 connection_pool: Default::default(),
179 handlers: Default::default(),
180 teardown: watch::channel(()).0,
181 };
182
183 server
184 .add_request_handler(ping)
185 .add_request_handler(create_room)
186 .add_request_handler(join_room)
187 .add_request_handler(rejoin_room)
188 .add_message_handler(leave_room)
189 .add_request_handler(call)
190 .add_request_handler(cancel_call)
191 .add_message_handler(decline_call)
192 .add_request_handler(update_participant_location)
193 .add_request_handler(share_project)
194 .add_message_handler(unshare_project)
195 .add_request_handler(join_project)
196 .add_message_handler(leave_project)
197 .add_request_handler(update_project)
198 .add_request_handler(update_worktree)
199 .add_message_handler(start_language_server)
200 .add_message_handler(update_language_server)
201 .add_message_handler(update_diagnostic_summary)
202 .add_request_handler(forward_project_request::<proto::GetHover>)
203 .add_request_handler(forward_project_request::<proto::GetDefinition>)
204 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetReferences>)
206 .add_request_handler(forward_project_request::<proto::SearchProject>)
207 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
208 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
209 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
212 .add_request_handler(forward_project_request::<proto::GetCompletions>)
213 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
214 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
216 .add_request_handler(forward_project_request::<proto::PrepareRename>)
217 .add_request_handler(forward_project_request::<proto::PerformRename>)
218 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
219 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
220 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
221 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
222 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
225 .add_message_handler(create_buffer_for_peer)
226 .add_request_handler(update_buffer)
227 .add_message_handler(update_buffer_file)
228 .add_message_handler(buffer_reloaded)
229 .add_message_handler(buffer_saved)
230 .add_request_handler(save_buffer)
231 .add_request_handler(get_users)
232 .add_request_handler(fuzzy_search_users)
233 .add_request_handler(request_contact)
234 .add_request_handler(remove_contact)
235 .add_request_handler(respond_to_contact_request)
236 .add_request_handler(follow)
237 .add_message_handler(unfollow)
238 .add_message_handler(update_followers)
239 .add_message_handler(update_diff_base)
240 .add_request_handler(get_private_user_info);
241
242 Arc::new(server)
243 }
244
245 pub async fn start(&self) -> Result<()> {
246 let server_id = *self.id.lock();
247 let app_state = self.app_state.clone();
248 let peer = self.peer.clone();
249 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
250 let pool = self.connection_pool.clone();
251 let live_kit_client = self.app_state.live_kit_client.clone();
252
253 let span = info_span!("start server");
254 self.executor.spawn_detached(
255 async move {
256 tracing::info!("waiting for cleanup timeout");
257 timeout.await;
258 tracing::info!("cleanup timeout expired, retrieving stale rooms");
259 if let Some(room_ids) = app_state
260 .db
261 .stale_room_ids(&app_state.config.zed_environment, server_id)
262 .await
263 .trace_err()
264 {
265 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
266 for room_id in room_ids {
267 let mut contacts_to_update = HashSet::default();
268 let mut canceled_calls_to_user_ids = Vec::new();
269 let mut live_kit_room = String::new();
270 let mut delete_live_kit_room = false;
271
272 if let Ok(mut refreshed_room) =
273 app_state.db.refresh_room(room_id, server_id).await
274 {
275 tracing::info!(
276 room_id = room_id.0,
277 new_participant_count = refreshed_room.room.participants.len(),
278 "refreshed room"
279 );
280 room_updated(&refreshed_room.room, &peer);
281 contacts_to_update
282 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
283 contacts_to_update
284 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
285 canceled_calls_to_user_ids =
286 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
287 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
288 delete_live_kit_room = refreshed_room.room.participants.is_empty();
289 }
290
291 {
292 let pool = pool.lock();
293 for canceled_user_id in canceled_calls_to_user_ids {
294 for connection_id in pool.user_connection_ids(canceled_user_id) {
295 peer.send(
296 connection_id,
297 proto::CallCanceled {
298 room_id: room_id.to_proto(),
299 },
300 )
301 .trace_err();
302 }
303 }
304 }
305
306 for user_id in contacts_to_update {
307 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
308 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
309 if let Some((busy, contacts)) = busy.zip(contacts) {
310 let pool = pool.lock();
311 let updated_contact = contact_for_user(user_id, false, busy, &pool);
312 for contact in contacts {
313 if let db::Contact::Accepted {
314 user_id: contact_user_id,
315 ..
316 } = contact
317 {
318 for contact_conn_id in
319 pool.user_connection_ids(contact_user_id)
320 {
321 peer.send(
322 contact_conn_id,
323 proto::UpdateContacts {
324 contacts: vec![updated_contact.clone()],
325 remove_contacts: Default::default(),
326 incoming_requests: Default::default(),
327 remove_incoming_requests: Default::default(),
328 outgoing_requests: Default::default(),
329 remove_outgoing_requests: Default::default(),
330 },
331 )
332 .trace_err();
333 }
334 }
335 }
336 }
337 }
338
339 if let Some(live_kit) = live_kit_client.as_ref() {
340 if delete_live_kit_room {
341 live_kit.delete_room(live_kit_room).await.trace_err();
342 }
343 }
344 }
345 }
346
347 app_state
348 .db
349 .delete_stale_servers(&app_state.config.zed_environment, server_id)
350 .await
351 .trace_err();
352 }
353 .instrument(span),
354 );
355 Ok(())
356 }
357
358 pub fn teardown(&self) {
359 self.peer.teardown();
360 self.connection_pool.lock().reset();
361 let _ = self.teardown.send(());
362 }
363
364 #[cfg(test)]
365 pub fn reset(&self, id: ServerId) {
366 self.teardown();
367 *self.id.lock() = id;
368 self.peer.reset(id.0 as u32);
369 }
370
371 #[cfg(test)]
372 pub fn id(&self) -> ServerId {
373 *self.id.lock()
374 }
375
376 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
377 where
378 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
379 Fut: 'static + Send + Future<Output = Result<()>>,
380 M: EnvelopedMessage,
381 {
382 let prev_handler = self.handlers.insert(
383 TypeId::of::<M>(),
384 Box::new(move |envelope, session| {
385 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
386 let span = info_span!(
387 "handle message",
388 payload_type = envelope.payload_type_name()
389 );
390 span.in_scope(|| {
391 tracing::info!(
392 payload_type = envelope.payload_type_name(),
393 "message received"
394 );
395 });
396 let future = (handler)(*envelope, session);
397 async move {
398 if let Err(error) = future.await {
399 tracing::error!(%error, "error handling message");
400 }
401 }
402 .instrument(span)
403 .boxed()
404 }),
405 );
406 if prev_handler.is_some() {
407 panic!("registered a handler for the same message twice");
408 }
409 self
410 }
411
412 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
413 where
414 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
415 Fut: 'static + Send + Future<Output = Result<()>>,
416 M: EnvelopedMessage,
417 {
418 self.add_handler(move |envelope, session| handler(envelope.payload, session));
419 self
420 }
421
422 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
423 where
424 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
425 Fut: Send + Future<Output = Result<()>>,
426 M: RequestMessage,
427 {
428 let handler = Arc::new(handler);
429 self.add_handler(move |envelope, session| {
430 let receipt = envelope.receipt();
431 let handler = handler.clone();
432 async move {
433 let peer = session.peer.clone();
434 let responded = Arc::new(AtomicBool::default());
435 let response = Response {
436 peer: peer.clone(),
437 responded: responded.clone(),
438 receipt,
439 };
440 match (handler)(envelope.payload, response, session).await {
441 Ok(()) => {
442 if responded.load(std::sync::atomic::Ordering::SeqCst) {
443 Ok(())
444 } else {
445 Err(anyhow!("handler did not send a response"))?
446 }
447 }
448 Err(error) => {
449 peer.respond_with_error(
450 receipt,
451 proto::Error {
452 message: error.to_string(),
453 },
454 )?;
455 Err(error)
456 }
457 }
458 }
459 })
460 }
461
462 pub fn handle_connection(
463 self: &Arc<Self>,
464 connection: Connection,
465 address: String,
466 user: User,
467 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
468 executor: Executor,
469 ) -> impl Future<Output = Result<()>> {
470 let this = self.clone();
471 let user_id = user.id;
472 let login = user.github_login;
473 let span = info_span!("handle connection", %user_id, %login, %address);
474 let mut teardown = self.teardown.subscribe();
475 async move {
476 let (connection_id, handle_io, mut incoming_rx) = this
477 .peer
478 .add_connection(connection, {
479 let executor = executor.clone();
480 move |duration| executor.sleep(duration)
481 });
482
483 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
484 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
485 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
486
487 if let Some(send_connection_id) = send_connection_id.take() {
488 let _ = send_connection_id.send(connection_id);
489 }
490
491 if !user.connected_once {
492 this.peer.send(connection_id, proto::ShowContacts {})?;
493 this.app_state.db.set_user_connected_once(user_id, true).await?;
494 }
495
496 let (contacts, invite_code) = future::try_join(
497 this.app_state.db.get_contacts(user_id),
498 this.app_state.db.get_invite_code_for_user(user_id)
499 ).await?;
500
501 {
502 let mut pool = this.connection_pool.lock();
503 pool.add_connection(connection_id, user_id, user.admin);
504 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
505
506 if let Some((code, count)) = invite_code {
507 this.peer.send(connection_id, proto::UpdateInviteInfo {
508 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
509 count: count as u32,
510 })?;
511 }
512 }
513
514 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
515 this.peer.send(connection_id, incoming_call)?;
516 }
517
518 let session = Session {
519 user_id,
520 connection_id,
521 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
522 peer: this.peer.clone(),
523 connection_pool: this.connection_pool.clone(),
524 live_kit_client: this.app_state.live_kit_client.clone()
525 };
526 update_user_contacts(user_id, &session).await?;
527
528 let handle_io = handle_io.fuse();
529 futures::pin_mut!(handle_io);
530
531 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
532 // This prevents deadlocks when e.g., client A performs a request to client B and
533 // client B performs a request to client A. If both clients stop processing further
534 // messages until their respective request completes, they won't have a chance to
535 // respond to the other client's request and cause a deadlock.
536 //
537 // This arrangement ensures we will attempt to process earlier messages first, but fall
538 // back to processing messages arrived later in the spirit of making progress.
539 let mut foreground_message_handlers = FuturesUnordered::new();
540 loop {
541 let next_message = incoming_rx.next().fuse();
542 futures::pin_mut!(next_message);
543 futures::select_biased! {
544 _ = teardown.changed().fuse() => return Ok(()),
545 result = handle_io => {
546 if let Err(error) = result {
547 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
548 }
549 break;
550 }
551 _ = foreground_message_handlers.next() => {}
552 message = next_message => {
553 if let Some(message) = message {
554 let type_name = message.payload_type_name();
555 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
556 let span_enter = span.enter();
557 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
558 let is_background = message.is_background();
559 let handle_message = (handler)(message, session.clone());
560 drop(span_enter);
561
562 let handle_message = handle_message.instrument(span);
563 if is_background {
564 executor.spawn_detached(handle_message);
565 } else {
566 foreground_message_handlers.push(handle_message);
567 }
568 } else {
569 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
570 }
571 } else {
572 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
573 break;
574 }
575 }
576 }
577 }
578
579 drop(foreground_message_handlers);
580 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
581 if let Err(error) = sign_out(session, teardown, executor).await {
582 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
583 }
584
585 Ok(())
586 }.instrument(span)
587 }
588
589 pub async fn invite_code_redeemed(
590 self: &Arc<Self>,
591 inviter_id: UserId,
592 invitee_id: UserId,
593 ) -> Result<()> {
594 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
595 if let Some(code) = &user.invite_code {
596 let pool = self.connection_pool.lock();
597 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
598 for connection_id in pool.user_connection_ids(inviter_id) {
599 self.peer.send(
600 connection_id,
601 proto::UpdateContacts {
602 contacts: vec![invitee_contact.clone()],
603 ..Default::default()
604 },
605 )?;
606 self.peer.send(
607 connection_id,
608 proto::UpdateInviteInfo {
609 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
610 count: user.invite_count as u32,
611 },
612 )?;
613 }
614 }
615 }
616 Ok(())
617 }
618
619 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
620 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
621 if let Some(invite_code) = &user.invite_code {
622 let pool = self.connection_pool.lock();
623 for connection_id in pool.user_connection_ids(user_id) {
624 self.peer.send(
625 connection_id,
626 proto::UpdateInviteInfo {
627 url: format!(
628 "{}{}",
629 self.app_state.config.invite_link_prefix, invite_code
630 ),
631 count: user.invite_count as u32,
632 },
633 )?;
634 }
635 }
636 }
637 Ok(())
638 }
639
640 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
641 ServerSnapshot {
642 connection_pool: ConnectionPoolGuard {
643 guard: self.connection_pool.lock(),
644 _not_send: PhantomData,
645 },
646 peer: &self.peer,
647 }
648 }
649}
650
651impl<'a> Deref for ConnectionPoolGuard<'a> {
652 type Target = ConnectionPool;
653
654 fn deref(&self) -> &Self::Target {
655 &*self.guard
656 }
657}
658
659impl<'a> DerefMut for ConnectionPoolGuard<'a> {
660 fn deref_mut(&mut self) -> &mut Self::Target {
661 &mut *self.guard
662 }
663}
664
665impl<'a> Drop for ConnectionPoolGuard<'a> {
666 fn drop(&mut self) {
667 #[cfg(test)]
668 self.check_invariants();
669 }
670}
671
672fn broadcast<F>(
673 sender_id: ConnectionId,
674 receiver_ids: impl IntoIterator<Item = ConnectionId>,
675 mut f: F,
676) where
677 F: FnMut(ConnectionId) -> anyhow::Result<()>,
678{
679 for receiver_id in receiver_ids {
680 if receiver_id != sender_id {
681 f(receiver_id).trace_err();
682 }
683 }
684}
685
686lazy_static! {
687 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
688}
689
690pub struct ProtocolVersion(u32);
691
692impl Header for ProtocolVersion {
693 fn name() -> &'static HeaderName {
694 &ZED_PROTOCOL_VERSION
695 }
696
697 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
698 where
699 Self: Sized,
700 I: Iterator<Item = &'i axum::http::HeaderValue>,
701 {
702 let version = values
703 .next()
704 .ok_or_else(axum::headers::Error::invalid)?
705 .to_str()
706 .map_err(|_| axum::headers::Error::invalid())?
707 .parse()
708 .map_err(|_| axum::headers::Error::invalid())?;
709 Ok(Self(version))
710 }
711
712 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
713 values.extend([self.0.to_string().parse().unwrap()]);
714 }
715}
716
717pub fn routes(server: Arc<Server>) -> Router<Body> {
718 Router::new()
719 .route("/rpc", get(handle_websocket_request))
720 .layer(
721 ServiceBuilder::new()
722 .layer(Extension(server.app_state.clone()))
723 .layer(middleware::from_fn(auth::validate_header)),
724 )
725 .route("/metrics", get(handle_metrics))
726 .layer(Extension(server))
727}
728
729pub async fn handle_websocket_request(
730 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
731 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
732 Extension(server): Extension<Arc<Server>>,
733 Extension(user): Extension<User>,
734 ws: WebSocketUpgrade,
735) -> axum::response::Response {
736 if protocol_version != rpc::PROTOCOL_VERSION {
737 return (
738 StatusCode::UPGRADE_REQUIRED,
739 "client must be upgraded".to_string(),
740 )
741 .into_response();
742 }
743 let socket_address = socket_address.to_string();
744 ws.on_upgrade(move |socket| {
745 use util::ResultExt;
746 let socket = socket
747 .map_ok(to_tungstenite_message)
748 .err_into()
749 .with(|message| async move { Ok(to_axum_message(message)) });
750 let connection = Connection::new(Box::pin(socket));
751 async move {
752 server
753 .handle_connection(connection, socket_address, user, None, Executor::Production)
754 .await
755 .log_err();
756 }
757 })
758}
759
760pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
761 let connections = server
762 .connection_pool
763 .lock()
764 .connections()
765 .filter(|connection| !connection.admin)
766 .count();
767
768 METRIC_CONNECTIONS.set(connections as _);
769
770 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
771 METRIC_SHARED_PROJECTS.set(shared_projects as _);
772
773 let encoder = prometheus::TextEncoder::new();
774 let metric_families = prometheus::gather();
775 let encoded_metrics = encoder
776 .encode_to_string(&metric_families)
777 .map_err(|err| anyhow!("{}", err))?;
778 Ok(encoded_metrics)
779}
780
/// Cleans up after a connection whose message loop has ended.
///
/// The cleanup happens in stages:
/// - It immediately disconnects the peer, removes the connection from the pool, and reports
///   the lost connection to the database (that last failure is only traced).
/// - It then waits for `RECONNECT_TIMEOUT` or server teardown. `select_biased!` polls the
///   timeout branch first.
/// - If the timeout wins, the session leaves its room (failure traced). If the user has no
///   other connection online, it calls `decline_call(None, user_id)` (presumably declining
///   any pending incoming call — confirm) and broadcasts the resulting room. Finally it
///   refreshes the user's contacts.
/// - On teardown nothing further happens.
///
/// `#[instrument(err, ...)]` records a returned error on the span; `executor` is skipped
/// from the span fields.
///
/// # Errors
///
/// Returns an error if `remove_connection` or `update_user_contacts` fails.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary of the condition, released before the body runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
821
822async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
823 response.send(proto::Ack {})?;
824 Ok(())
825}
826
827async fn create_room(
828 _request: proto::CreateRoom,
829 response: Response<proto::CreateRoom>,
830 session: Session,
831) -> Result<()> {
832 let live_kit_room = nanoid::nanoid!(30);
833 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
834 if let Some(_) = live_kit
835 .create_room(live_kit_room.clone())
836 .await
837 .trace_err()
838 {
839 if let Some(token) = live_kit
840 .room_token(&live_kit_room, &session.user_id.to_string())
841 .trace_err()
842 {
843 Some(proto::LiveKitConnectionInfo {
844 server_url: live_kit.url().into(),
845 token,
846 })
847 } else {
848 None
849 }
850 } else {
851 None
852 }
853 } else {
854 None
855 };
856
857 {
858 let room = session
859 .db()
860 .await
861 .create_room(session.user_id, session.connection_id, &live_kit_room)
862 .await?;
863
864 response.send(proto::CreateRoomResponse {
865 room: Some(room.clone()),
866 live_kit_connection_info,
867 })?;
868 }
869
870 update_user_contacts(session.user_id, &session).await?;
871 Ok(())
872}
873
874async fn join_room(
875 request: proto::JoinRoom,
876 response: Response<proto::JoinRoom>,
877 session: Session,
878) -> Result<()> {
879 let room_id = RoomId::from_proto(request.id);
880 let room = {
881 let room = session
882 .db()
883 .await
884 .join_room(room_id, session.user_id, session.connection_id)
885 .await?;
886 room_updated(&room, &session.peer);
887 room.clone()
888 };
889
890 for connection_id in session
891 .connection_pool()
892 .await
893 .user_connection_ids(session.user_id)
894 {
895 session
896 .peer
897 .send(
898 connection_id,
899 proto::CallCanceled {
900 room_id: room_id.to_proto(),
901 },
902 )
903 .trace_err();
904 }
905
906 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
907 if let Some(token) = live_kit
908 .room_token(&room.live_kit_room, &session.user_id.to_string())
909 .trace_err()
910 {
911 Some(proto::LiveKitConnectionInfo {
912 server_url: live_kit.url().into(),
913 token,
914 })
915 } else {
916 None
917 }
918 } else {
919 None
920 };
921
922 response.send(proto::JoinRoomResponse {
923 room: Some(room),
924 live_kit_connection_info,
925 })?;
926
927 update_user_contacts(session.user_id, &session).await?;
928 Ok(())
929}
930
931async fn rejoin_room(
932 request: proto::RejoinRoom,
933 response: Response<proto::RejoinRoom>,
934 session: Session,
935) -> Result<()> {
936 {
937 let mut rejoined_room = session
938 .db()
939 .await
940 .rejoin_room(request, session.user_id, session.connection_id)
941 .await?;
942
943 response.send(proto::RejoinRoomResponse {
944 room: Some(rejoined_room.room.clone()),
945 reshared_projects: rejoined_room
946 .reshared_projects
947 .iter()
948 .map(|project| proto::ResharedProject {
949 id: project.id.to_proto(),
950 collaborators: project
951 .collaborators
952 .iter()
953 .map(|collaborator| collaborator.to_proto())
954 .collect(),
955 })
956 .collect(),
957 rejoined_projects: rejoined_room
958 .rejoined_projects
959 .iter()
960 .map(|rejoined_project| proto::RejoinedProject {
961 id: rejoined_project.id.to_proto(),
962 worktrees: rejoined_project
963 .worktrees
964 .iter()
965 .map(|worktree| proto::WorktreeMetadata {
966 id: worktree.id,
967 root_name: worktree.root_name.clone(),
968 visible: worktree.visible,
969 abs_path: worktree.abs_path.clone(),
970 })
971 .collect(),
972 collaborators: rejoined_project
973 .collaborators
974 .iter()
975 .map(|collaborator| collaborator.to_proto())
976 .collect(),
977 language_servers: rejoined_project.language_servers.clone(),
978 })
979 .collect(),
980 })?;
981 room_updated(&rejoined_room.room, &session.peer);
982
983 for project in &rejoined_room.reshared_projects {
984 for collaborator in &project.collaborators {
985 session
986 .peer
987 .send(
988 collaborator.connection_id,
989 proto::UpdateProjectCollaborator {
990 project_id: project.id.to_proto(),
991 old_peer_id: Some(project.old_connection_id.into()),
992 new_peer_id: Some(session.connection_id.into()),
993 },
994 )
995 .trace_err();
996 }
997
998 broadcast(
999 session.connection_id,
1000 project
1001 .collaborators
1002 .iter()
1003 .map(|collaborator| collaborator.connection_id),
1004 |connection_id| {
1005 session.peer.forward_send(
1006 session.connection_id,
1007 connection_id,
1008 proto::UpdateProject {
1009 project_id: project.id.to_proto(),
1010 worktrees: project.worktrees.clone(),
1011 },
1012 )
1013 },
1014 );
1015 }
1016
1017 for project in &rejoined_room.rejoined_projects {
1018 for collaborator in &project.collaborators {
1019 session
1020 .peer
1021 .send(
1022 collaborator.connection_id,
1023 proto::UpdateProjectCollaborator {
1024 project_id: project.id.to_proto(),
1025 old_peer_id: Some(project.old_connection_id.into()),
1026 new_peer_id: Some(session.connection_id.into()),
1027 },
1028 )
1029 .trace_err();
1030 }
1031 }
1032
1033 for project in &mut rejoined_room.rejoined_projects {
1034 for worktree in mem::take(&mut project.worktrees) {
1035 #[cfg(any(test, feature = "test-support"))]
1036 const MAX_CHUNK_SIZE: usize = 2;
1037 #[cfg(not(any(test, feature = "test-support")))]
1038 const MAX_CHUNK_SIZE: usize = 256;
1039
1040 // Stream this worktree's entries.
1041 let message = proto::UpdateWorktree {
1042 project_id: project.id.to_proto(),
1043 worktree_id: worktree.id,
1044 abs_path: worktree.abs_path.clone(),
1045 root_name: worktree.root_name,
1046 updated_entries: worktree.updated_entries,
1047 removed_entries: worktree.removed_entries,
1048 scan_id: worktree.scan_id,
1049 is_last_update: worktree.is_complete,
1050 };
1051 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1052 session.peer.send(session.connection_id, update.clone())?;
1053 }
1054
1055 // Stream this worktree's diagnostics.
1056 for summary in worktree.diagnostic_summaries {
1057 session.peer.send(
1058 session.connection_id,
1059 proto::UpdateDiagnosticSummary {
1060 project_id: project.id.to_proto(),
1061 worktree_id: worktree.id,
1062 summary: Some(summary),
1063 },
1064 )?;
1065 }
1066 }
1067
1068 for language_server in &project.language_servers {
1069 session.peer.send(
1070 session.connection_id,
1071 proto::UpdateLanguageServer {
1072 project_id: project.id.to_proto(),
1073 language_server_id: language_server.id,
1074 variant: Some(
1075 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1076 proto::LspDiskBasedDiagnosticsUpdated {},
1077 ),
1078 ),
1079 },
1080 )?;
1081 }
1082 }
1083 }
1084
1085 update_user_contacts(session.user_id, &session).await?;
1086 Ok(())
1087}
1088
1089async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
1090 leave_room_for_session(&session).await
1091}
1092
1093async fn call(
1094 request: proto::Call,
1095 response: Response<proto::Call>,
1096 session: Session,
1097) -> Result<()> {
1098 let room_id = RoomId::from_proto(request.room_id);
1099 let calling_user_id = session.user_id;
1100 let calling_connection_id = session.connection_id;
1101 let called_user_id = UserId::from_proto(request.called_user_id);
1102 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1103 if !session
1104 .db()
1105 .await
1106 .has_contact(calling_user_id, called_user_id)
1107 .await?
1108 {
1109 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1110 }
1111
1112 let incoming_call = {
1113 let (room, incoming_call) = &mut *session
1114 .db()
1115 .await
1116 .call(
1117 room_id,
1118 calling_user_id,
1119 calling_connection_id,
1120 called_user_id,
1121 initial_project_id,
1122 )
1123 .await?;
1124 room_updated(&room, &session.peer);
1125 mem::take(incoming_call)
1126 };
1127 update_user_contacts(called_user_id, &session).await?;
1128
1129 let mut calls = session
1130 .connection_pool()
1131 .await
1132 .user_connection_ids(called_user_id)
1133 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1134 .collect::<FuturesUnordered<_>>();
1135
1136 while let Some(call_response) = calls.next().await {
1137 match call_response.as_ref() {
1138 Ok(_) => {
1139 response.send(proto::Ack {})?;
1140 return Ok(());
1141 }
1142 Err(_) => {
1143 call_response.trace_err();
1144 }
1145 }
1146 }
1147
1148 {
1149 let room = session
1150 .db()
1151 .await
1152 .call_failed(room_id, called_user_id)
1153 .await?;
1154 room_updated(&room, &session.peer);
1155 }
1156 update_user_contacts(called_user_id, &session).await?;
1157
1158 Err(anyhow!("failed to ring user"))?
1159}
1160
1161async fn cancel_call(
1162 request: proto::CancelCall,
1163 response: Response<proto::CancelCall>,
1164 session: Session,
1165) -> Result<()> {
1166 let called_user_id = UserId::from_proto(request.called_user_id);
1167 let room_id = RoomId::from_proto(request.room_id);
1168 {
1169 let room = session
1170 .db()
1171 .await
1172 .cancel_call(room_id, session.connection_id, called_user_id)
1173 .await?;
1174 room_updated(&room, &session.peer);
1175 }
1176
1177 for connection_id in session
1178 .connection_pool()
1179 .await
1180 .user_connection_ids(called_user_id)
1181 {
1182 session
1183 .peer
1184 .send(
1185 connection_id,
1186 proto::CallCanceled {
1187 room_id: room_id.to_proto(),
1188 },
1189 )
1190 .trace_err();
1191 }
1192 response.send(proto::Ack {})?;
1193
1194 update_user_contacts(called_user_id, &session).await?;
1195 Ok(())
1196}
1197
1198async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1199 let room_id = RoomId::from_proto(message.room_id);
1200 {
1201 let room = session
1202 .db()
1203 .await
1204 .decline_call(Some(room_id), session.user_id)
1205 .await?
1206 .ok_or_else(|| anyhow!("failed to decline call"))?;
1207 room_updated(&room, &session.peer);
1208 }
1209
1210 for connection_id in session
1211 .connection_pool()
1212 .await
1213 .user_connection_ids(session.user_id)
1214 {
1215 session
1216 .peer
1217 .send(
1218 connection_id,
1219 proto::CallCanceled {
1220 room_id: room_id.to_proto(),
1221 },
1222 )
1223 .trace_err();
1224 }
1225 update_user_contacts(session.user_id, &session).await?;
1226 Ok(())
1227}
1228
1229async fn update_participant_location(
1230 request: proto::UpdateParticipantLocation,
1231 response: Response<proto::UpdateParticipantLocation>,
1232 session: Session,
1233) -> Result<()> {
1234 let room_id = RoomId::from_proto(request.room_id);
1235 let location = request
1236 .location
1237 .ok_or_else(|| anyhow!("invalid location"))?;
1238 let room = session
1239 .db()
1240 .await
1241 .update_room_participant_location(room_id, session.connection_id, location)
1242 .await?;
1243 room_updated(&room, &session.peer);
1244 response.send(proto::Ack {})?;
1245 Ok(())
1246}
1247
1248async fn share_project(
1249 request: proto::ShareProject,
1250 response: Response<proto::ShareProject>,
1251 session: Session,
1252) -> Result<()> {
1253 let (project_id, room) = &*session
1254 .db()
1255 .await
1256 .share_project(
1257 RoomId::from_proto(request.room_id),
1258 session.connection_id,
1259 &request.worktrees,
1260 )
1261 .await?;
1262 response.send(proto::ShareProjectResponse {
1263 project_id: project_id.to_proto(),
1264 })?;
1265 room_updated(&room, &session.peer);
1266
1267 Ok(())
1268}
1269
1270async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1271 let project_id = ProjectId::from_proto(message.project_id);
1272
1273 let (room, guest_connection_ids) = &*session
1274 .db()
1275 .await
1276 .unshare_project(project_id, session.connection_id)
1277 .await?;
1278
1279 broadcast(
1280 session.connection_id,
1281 guest_connection_ids.iter().copied(),
1282 |conn_id| session.peer.send(conn_id, message.clone()),
1283 );
1284 room_updated(&room, &session.peer);
1285
1286 Ok(())
1287}
1288
1289async fn join_project(
1290 request: proto::JoinProject,
1291 response: Response<proto::JoinProject>,
1292 session: Session,
1293) -> Result<()> {
1294 let project_id = ProjectId::from_proto(request.project_id);
1295 let guest_user_id = session.user_id;
1296
1297 tracing::info!(%project_id, "join project");
1298
1299 let (project, replica_id) = &mut *session
1300 .db()
1301 .await
1302 .join_project(project_id, session.connection_id)
1303 .await?;
1304
1305 let collaborators = project
1306 .collaborators
1307 .iter()
1308 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1309 .map(|collaborator| collaborator.to_proto())
1310 .collect::<Vec<_>>();
1311 let worktrees = project
1312 .worktrees
1313 .iter()
1314 .map(|(id, worktree)| proto::WorktreeMetadata {
1315 id: *id,
1316 root_name: worktree.root_name.clone(),
1317 visible: worktree.visible,
1318 abs_path: worktree.abs_path.clone(),
1319 })
1320 .collect::<Vec<_>>();
1321
1322 for collaborator in &collaborators {
1323 session
1324 .peer
1325 .send(
1326 collaborator.peer_id.unwrap().into(),
1327 proto::AddProjectCollaborator {
1328 project_id: project_id.to_proto(),
1329 collaborator: Some(proto::Collaborator {
1330 peer_id: Some(session.connection_id.into()),
1331 replica_id: replica_id.0 as u32,
1332 user_id: guest_user_id.to_proto(),
1333 }),
1334 },
1335 )
1336 .trace_err();
1337 }
1338
1339 // First, we send the metadata associated with each worktree.
1340 response.send(proto::JoinProjectResponse {
1341 worktrees: worktrees.clone(),
1342 replica_id: replica_id.0 as u32,
1343 collaborators: collaborators.clone(),
1344 language_servers: project.language_servers.clone(),
1345 })?;
1346
1347 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1348 #[cfg(any(test, feature = "test-support"))]
1349 const MAX_CHUNK_SIZE: usize = 2;
1350 #[cfg(not(any(test, feature = "test-support")))]
1351 const MAX_CHUNK_SIZE: usize = 256;
1352
1353 // Stream this worktree's entries.
1354 let message = proto::UpdateWorktree {
1355 project_id: project_id.to_proto(),
1356 worktree_id,
1357 abs_path: worktree.abs_path.clone(),
1358 root_name: worktree.root_name,
1359 updated_entries: worktree.entries,
1360 removed_entries: Default::default(),
1361 scan_id: worktree.scan_id,
1362 is_last_update: worktree.is_complete,
1363 };
1364 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1365 session.peer.send(session.connection_id, update.clone())?;
1366 }
1367
1368 // Stream this worktree's diagnostics.
1369 for summary in worktree.diagnostic_summaries {
1370 session.peer.send(
1371 session.connection_id,
1372 proto::UpdateDiagnosticSummary {
1373 project_id: project_id.to_proto(),
1374 worktree_id: worktree.id,
1375 summary: Some(summary),
1376 },
1377 )?;
1378 }
1379 }
1380
1381 for language_server in &project.language_servers {
1382 session.peer.send(
1383 session.connection_id,
1384 proto::UpdateLanguageServer {
1385 project_id: project_id.to_proto(),
1386 language_server_id: language_server.id,
1387 variant: Some(
1388 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1389 proto::LspDiskBasedDiagnosticsUpdated {},
1390 ),
1391 ),
1392 },
1393 )?;
1394 }
1395
1396 Ok(())
1397}
1398
1399async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1400 let sender_id = session.connection_id;
1401 let project_id = ProjectId::from_proto(request.project_id);
1402
1403 let project = session
1404 .db()
1405 .await
1406 .leave_project(project_id, sender_id)
1407 .await?;
1408 tracing::info!(
1409 %project_id,
1410 host_user_id = %project.host_user_id,
1411 host_connection_id = %project.host_connection_id,
1412 "leave project"
1413 );
1414 project_left(&project, &session);
1415
1416 Ok(())
1417}
1418
1419async fn update_project(
1420 request: proto::UpdateProject,
1421 response: Response<proto::UpdateProject>,
1422 session: Session,
1423) -> Result<()> {
1424 let project_id = ProjectId::from_proto(request.project_id);
1425 let (room, guest_connection_ids) = &*session
1426 .db()
1427 .await
1428 .update_project(project_id, session.connection_id, &request.worktrees)
1429 .await?;
1430 broadcast(
1431 session.connection_id,
1432 guest_connection_ids.iter().copied(),
1433 |connection_id| {
1434 session
1435 .peer
1436 .forward_send(session.connection_id, connection_id, request.clone())
1437 },
1438 );
1439 room_updated(&room, &session.peer);
1440 response.send(proto::Ack {})?;
1441
1442 Ok(())
1443}
1444
1445async fn update_worktree(
1446 request: proto::UpdateWorktree,
1447 response: Response<proto::UpdateWorktree>,
1448 session: Session,
1449) -> Result<()> {
1450 let guest_connection_ids = session
1451 .db()
1452 .await
1453 .update_worktree(&request, session.connection_id)
1454 .await?;
1455
1456 broadcast(
1457 session.connection_id,
1458 guest_connection_ids.iter().copied(),
1459 |connection_id| {
1460 session
1461 .peer
1462 .forward_send(session.connection_id, connection_id, request.clone())
1463 },
1464 );
1465 response.send(proto::Ack {})?;
1466 Ok(())
1467}
1468
1469async fn update_diagnostic_summary(
1470 message: proto::UpdateDiagnosticSummary,
1471 session: Session,
1472) -> Result<()> {
1473 let guest_connection_ids = session
1474 .db()
1475 .await
1476 .update_diagnostic_summary(&message, session.connection_id)
1477 .await?;
1478
1479 broadcast(
1480 session.connection_id,
1481 guest_connection_ids.iter().copied(),
1482 |connection_id| {
1483 session
1484 .peer
1485 .forward_send(session.connection_id, connection_id, message.clone())
1486 },
1487 );
1488
1489 Ok(())
1490}
1491
1492async fn start_language_server(
1493 request: proto::StartLanguageServer,
1494 session: Session,
1495) -> Result<()> {
1496 let guest_connection_ids = session
1497 .db()
1498 .await
1499 .start_language_server(&request, session.connection_id)
1500 .await?;
1501
1502 broadcast(
1503 session.connection_id,
1504 guest_connection_ids.iter().copied(),
1505 |connection_id| {
1506 session
1507 .peer
1508 .forward_send(session.connection_id, connection_id, request.clone())
1509 },
1510 );
1511 Ok(())
1512}
1513
1514async fn update_language_server(
1515 request: proto::UpdateLanguageServer,
1516 session: Session,
1517) -> Result<()> {
1518 let project_id = ProjectId::from_proto(request.project_id);
1519 let project_connection_ids = session
1520 .db()
1521 .await
1522 .project_connection_ids(project_id, session.connection_id)
1523 .await?;
1524 broadcast(
1525 session.connection_id,
1526 project_connection_ids.iter().copied(),
1527 |connection_id| {
1528 session
1529 .peer
1530 .forward_send(session.connection_id, connection_id, request.clone())
1531 },
1532 );
1533 Ok(())
1534}
1535
1536async fn forward_project_request<T>(
1537 request: T,
1538 response: Response<T>,
1539 session: Session,
1540) -> Result<()>
1541where
1542 T: EntityMessage + RequestMessage,
1543{
1544 let project_id = ProjectId::from_proto(request.remote_entity_id());
1545 let host_connection_id = {
1546 let collaborators = session
1547 .db()
1548 .await
1549 .project_collaborators(project_id, session.connection_id)
1550 .await?;
1551 collaborators
1552 .iter()
1553 .find(|collaborator| collaborator.is_host)
1554 .ok_or_else(|| anyhow!("host not found"))?
1555 .connection_id
1556 };
1557
1558 let payload = session
1559 .peer
1560 .forward_request(session.connection_id, host_connection_id, request)
1561 .await?;
1562
1563 response.send(payload)?;
1564 Ok(())
1565}
1566
1567async fn save_buffer(
1568 request: proto::SaveBuffer,
1569 response: Response<proto::SaveBuffer>,
1570 session: Session,
1571) -> Result<()> {
1572 let project_id = ProjectId::from_proto(request.project_id);
1573 let host_connection_id = {
1574 let collaborators = session
1575 .db()
1576 .await
1577 .project_collaborators(project_id, session.connection_id)
1578 .await?;
1579 collaborators
1580 .iter()
1581 .find(|collaborator| collaborator.is_host)
1582 .ok_or_else(|| anyhow!("host not found"))?
1583 .connection_id
1584 };
1585 let response_payload = session
1586 .peer
1587 .forward_request(session.connection_id, host_connection_id, request.clone())
1588 .await?;
1589
1590 let mut collaborators = session
1591 .db()
1592 .await
1593 .project_collaborators(project_id, session.connection_id)
1594 .await?;
1595 collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id);
1596 let project_connection_ids = collaborators
1597 .iter()
1598 .map(|collaborator| collaborator.connection_id);
1599 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1600 session
1601 .peer
1602 .forward_send(host_connection_id, conn_id, response_payload.clone())
1603 });
1604 response.send(response_payload)?;
1605 Ok(())
1606}
1607
1608async fn create_buffer_for_peer(
1609 request: proto::CreateBufferForPeer,
1610 session: Session,
1611) -> Result<()> {
1612 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1613 session
1614 .peer
1615 .forward_send(session.connection_id, peer_id.into(), request)?;
1616 Ok(())
1617}
1618
1619async fn update_buffer(
1620 request: proto::UpdateBuffer,
1621 response: Response<proto::UpdateBuffer>,
1622 session: Session,
1623) -> Result<()> {
1624 let project_id = ProjectId::from_proto(request.project_id);
1625 let project_connection_ids = session
1626 .db()
1627 .await
1628 .project_connection_ids(project_id, session.connection_id)
1629 .await?;
1630
1631 broadcast(
1632 session.connection_id,
1633 project_connection_ids.iter().copied(),
1634 |connection_id| {
1635 session
1636 .peer
1637 .forward_send(session.connection_id, connection_id, request.clone())
1638 },
1639 );
1640 response.send(proto::Ack {})?;
1641 Ok(())
1642}
1643
1644async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1645 let project_id = ProjectId::from_proto(request.project_id);
1646 let project_connection_ids = session
1647 .db()
1648 .await
1649 .project_connection_ids(project_id, session.connection_id)
1650 .await?;
1651
1652 broadcast(
1653 session.connection_id,
1654 project_connection_ids.iter().copied(),
1655 |connection_id| {
1656 session
1657 .peer
1658 .forward_send(session.connection_id, connection_id, request.clone())
1659 },
1660 );
1661 Ok(())
1662}
1663
1664async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1665 let project_id = ProjectId::from_proto(request.project_id);
1666 let project_connection_ids = session
1667 .db()
1668 .await
1669 .project_connection_ids(project_id, session.connection_id)
1670 .await?;
1671 broadcast(
1672 session.connection_id,
1673 project_connection_ids.iter().copied(),
1674 |connection_id| {
1675 session
1676 .peer
1677 .forward_send(session.connection_id, connection_id, request.clone())
1678 },
1679 );
1680 Ok(())
1681}
1682
1683async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1684 let project_id = ProjectId::from_proto(request.project_id);
1685 let project_connection_ids = session
1686 .db()
1687 .await
1688 .project_connection_ids(project_id, session.connection_id)
1689 .await?;
1690 broadcast(
1691 session.connection_id,
1692 project_connection_ids.iter().copied(),
1693 |connection_id| {
1694 session
1695 .peer
1696 .forward_send(session.connection_id, connection_id, request.clone())
1697 },
1698 );
1699 Ok(())
1700}
1701
1702async fn follow(
1703 request: proto::Follow,
1704 response: Response<proto::Follow>,
1705 session: Session,
1706) -> Result<()> {
1707 let project_id = ProjectId::from_proto(request.project_id);
1708 let leader_id = request
1709 .leader_id
1710 .ok_or_else(|| anyhow!("invalid leader id"))?
1711 .into();
1712 let follower_id = session.connection_id;
1713 {
1714 let project_connection_ids = session
1715 .db()
1716 .await
1717 .project_connection_ids(project_id, session.connection_id)
1718 .await?;
1719
1720 if !project_connection_ids.contains(&leader_id) {
1721 Err(anyhow!("no such peer"))?;
1722 }
1723 }
1724
1725 let mut response_payload = session
1726 .peer
1727 .forward_request(session.connection_id, leader_id, request)
1728 .await?;
1729 response_payload
1730 .views
1731 .retain(|view| view.leader_id != Some(follower_id.into()));
1732 response.send(response_payload)?;
1733 Ok(())
1734}
1735
1736async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1737 let project_id = ProjectId::from_proto(request.project_id);
1738 let leader_id = request
1739 .leader_id
1740 .ok_or_else(|| anyhow!("invalid leader id"))?
1741 .into();
1742 let project_connection_ids = session
1743 .db()
1744 .await
1745 .project_connection_ids(project_id, session.connection_id)
1746 .await?;
1747 if !project_connection_ids.contains(&leader_id) {
1748 Err(anyhow!("no such peer"))?;
1749 }
1750 session
1751 .peer
1752 .forward_send(session.connection_id, leader_id, request)?;
1753 Ok(())
1754}
1755
1756async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1757 let project_id = ProjectId::from_proto(request.project_id);
1758 let project_connection_ids = session
1759 .db
1760 .lock()
1761 .await
1762 .project_connection_ids(project_id, session.connection_id)
1763 .await?;
1764
1765 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1766 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1767 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1768 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1769 });
1770 for follower_peer_id in request.follower_ids.iter().copied() {
1771 let follower_connection_id = follower_peer_id.into();
1772 if project_connection_ids.contains(&follower_connection_id)
1773 && Some(follower_peer_id) != leader_id
1774 {
1775 session.peer.forward_send(
1776 session.connection_id,
1777 follower_connection_id,
1778 request.clone(),
1779 )?;
1780 }
1781 }
1782 Ok(())
1783}
1784
1785async fn get_users(
1786 request: proto::GetUsers,
1787 response: Response<proto::GetUsers>,
1788 session: Session,
1789) -> Result<()> {
1790 let user_ids = request
1791 .user_ids
1792 .into_iter()
1793 .map(UserId::from_proto)
1794 .collect();
1795 let users = session
1796 .db()
1797 .await
1798 .get_users_by_ids(user_ids)
1799 .await?
1800 .into_iter()
1801 .map(|user| proto::User {
1802 id: user.id.to_proto(),
1803 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1804 github_login: user.github_login,
1805 })
1806 .collect();
1807 response.send(proto::UsersResponse { users })?;
1808 Ok(())
1809}
1810
1811async fn fuzzy_search_users(
1812 request: proto::FuzzySearchUsers,
1813 response: Response<proto::FuzzySearchUsers>,
1814 session: Session,
1815) -> Result<()> {
1816 let query = request.query;
1817 let users = match query.len() {
1818 0 => vec![],
1819 1 | 2 => session
1820 .db()
1821 .await
1822 .get_user_by_github_account(&query, None)
1823 .await?
1824 .into_iter()
1825 .collect(),
1826 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1827 };
1828 let users = users
1829 .into_iter()
1830 .filter(|user| user.id != session.user_id)
1831 .map(|user| proto::User {
1832 id: user.id.to_proto(),
1833 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1834 github_login: user.github_login,
1835 })
1836 .collect();
1837 response.send(proto::UsersResponse { users })?;
1838 Ok(())
1839}
1840
1841async fn request_contact(
1842 request: proto::RequestContact,
1843 response: Response<proto::RequestContact>,
1844 session: Session,
1845) -> Result<()> {
1846 let requester_id = session.user_id;
1847 let responder_id = UserId::from_proto(request.responder_id);
1848 if requester_id == responder_id {
1849 return Err(anyhow!("cannot add yourself as a contact"))?;
1850 }
1851
1852 session
1853 .db()
1854 .await
1855 .send_contact_request(requester_id, responder_id)
1856 .await?;
1857
1858 // Update outgoing contact requests of requester
1859 let mut update = proto::UpdateContacts::default();
1860 update.outgoing_requests.push(responder_id.to_proto());
1861 for connection_id in session
1862 .connection_pool()
1863 .await
1864 .user_connection_ids(requester_id)
1865 {
1866 session.peer.send(connection_id, update.clone())?;
1867 }
1868
1869 // Update incoming contact requests of responder
1870 let mut update = proto::UpdateContacts::default();
1871 update
1872 .incoming_requests
1873 .push(proto::IncomingContactRequest {
1874 requester_id: requester_id.to_proto(),
1875 should_notify: true,
1876 });
1877 for connection_id in session
1878 .connection_pool()
1879 .await
1880 .user_connection_ids(responder_id)
1881 {
1882 session.peer.send(connection_id, update.clone())?;
1883 }
1884
1885 response.send(proto::Ack {})?;
1886 Ok(())
1887}
1888
1889async fn respond_to_contact_request(
1890 request: proto::RespondToContactRequest,
1891 response: Response<proto::RespondToContactRequest>,
1892 session: Session,
1893) -> Result<()> {
1894 let responder_id = session.user_id;
1895 let requester_id = UserId::from_proto(request.requester_id);
1896 let db = session.db().await;
1897 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1898 db.dismiss_contact_notification(responder_id, requester_id)
1899 .await?;
1900 } else {
1901 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1902
1903 db.respond_to_contact_request(responder_id, requester_id, accept)
1904 .await?;
1905 let requester_busy = db.is_user_busy(requester_id).await?;
1906 let responder_busy = db.is_user_busy(responder_id).await?;
1907
1908 let pool = session.connection_pool().await;
1909 // Update responder with new contact
1910 let mut update = proto::UpdateContacts::default();
1911 if accept {
1912 update
1913 .contacts
1914 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1915 }
1916 update
1917 .remove_incoming_requests
1918 .push(requester_id.to_proto());
1919 for connection_id in pool.user_connection_ids(responder_id) {
1920 session.peer.send(connection_id, update.clone())?;
1921 }
1922
1923 // Update requester with new contact
1924 let mut update = proto::UpdateContacts::default();
1925 if accept {
1926 update
1927 .contacts
1928 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1929 }
1930 update
1931 .remove_outgoing_requests
1932 .push(responder_id.to_proto());
1933 for connection_id in pool.user_connection_ids(requester_id) {
1934 session.peer.send(connection_id, update.clone())?;
1935 }
1936 }
1937
1938 response.send(proto::Ack {})?;
1939 Ok(())
1940}
1941
1942async fn remove_contact(
1943 request: proto::RemoveContact,
1944 response: Response<proto::RemoveContact>,
1945 session: Session,
1946) -> Result<()> {
1947 let requester_id = session.user_id;
1948 let responder_id = UserId::from_proto(request.user_id);
1949 let db = session.db().await;
1950 db.remove_contact(requester_id, responder_id).await?;
1951
1952 let pool = session.connection_pool().await;
1953 // Update outgoing contact requests of requester
1954 let mut update = proto::UpdateContacts::default();
1955 update
1956 .remove_outgoing_requests
1957 .push(responder_id.to_proto());
1958 for connection_id in pool.user_connection_ids(requester_id) {
1959 session.peer.send(connection_id, update.clone())?;
1960 }
1961
1962 // Update incoming contact requests of responder
1963 let mut update = proto::UpdateContacts::default();
1964 update
1965 .remove_incoming_requests
1966 .push(requester_id.to_proto());
1967 for connection_id in pool.user_connection_ids(responder_id) {
1968 session.peer.send(connection_id, update.clone())?;
1969 }
1970
1971 response.send(proto::Ack {})?;
1972 Ok(())
1973}
1974
1975async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1976 let project_id = ProjectId::from_proto(request.project_id);
1977 let project_connection_ids = session
1978 .db()
1979 .await
1980 .project_connection_ids(project_id, session.connection_id)
1981 .await?;
1982 broadcast(
1983 session.connection_id,
1984 project_connection_ids.iter().copied(),
1985 |connection_id| {
1986 session
1987 .peer
1988 .forward_send(session.connection_id, connection_id, request.clone())
1989 },
1990 );
1991 Ok(())
1992}
1993
1994async fn get_private_user_info(
1995 _request: proto::GetPrivateUserInfo,
1996 response: Response<proto::GetPrivateUserInfo>,
1997 session: Session,
1998) -> Result<()> {
1999 let metrics_id = session
2000 .db()
2001 .await
2002 .get_user_metrics_id(session.user_id)
2003 .await?;
2004 let user = session
2005 .db()
2006 .await
2007 .get_user_by_id(session.user_id)
2008 .await?
2009 .ok_or_else(|| anyhow!("user not found"))?;
2010 response.send(proto::GetPrivateUserInfoResponse {
2011 metrics_id,
2012 staff: user.admin,
2013 })?;
2014 Ok(())
2015}
2016
2017fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2018 match message {
2019 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2020 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2021 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2022 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2023 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2024 code: frame.code.into(),
2025 reason: frame.reason,
2026 })),
2027 }
2028}
2029
2030fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2031 match message {
2032 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2033 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2034 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2035 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2036 AxumMessage::Close(frame) => {
2037 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2038 code: frame.code.into(),
2039 reason: frame.reason,
2040 }))
2041 }
2042 }
2043}
2044
2045fn build_initial_contacts_update(
2046 contacts: Vec<db::Contact>,
2047 pool: &ConnectionPool,
2048) -> proto::UpdateContacts {
2049 let mut update = proto::UpdateContacts::default();
2050
2051 for contact in contacts {
2052 match contact {
2053 db::Contact::Accepted {
2054 user_id,
2055 should_notify,
2056 busy,
2057 } => {
2058 update
2059 .contacts
2060 .push(contact_for_user(user_id, should_notify, busy, &pool));
2061 }
2062 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2063 db::Contact::Incoming {
2064 user_id,
2065 should_notify,
2066 } => update
2067 .incoming_requests
2068 .push(proto::IncomingContactRequest {
2069 requester_id: user_id.to_proto(),
2070 should_notify,
2071 }),
2072 }
2073 }
2074
2075 update
2076}
2077
2078fn contact_for_user(
2079 user_id: UserId,
2080 should_notify: bool,
2081 busy: bool,
2082 pool: &ConnectionPool,
2083) -> proto::Contact {
2084 proto::Contact {
2085 user_id: user_id.to_proto(),
2086 online: pool.is_user_online(user_id),
2087 busy,
2088 should_notify,
2089 }
2090}
2091
2092fn room_updated(room: &proto::Room, peer: &Peer) {
2093 for participant in &room.participants {
2094 if let Some(peer_id) = participant
2095 .peer_id
2096 .ok_or_else(|| anyhow!("invalid participant peer id"))
2097 .trace_err()
2098 {
2099 peer.send(
2100 peer_id.into(),
2101 proto::RoomUpdated {
2102 room: Some(room.clone()),
2103 },
2104 )
2105 .trace_err();
2106 }
2107 }
2108}
2109
2110async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2111 let db = session.db().await;
2112 let contacts = db.get_contacts(user_id).await?;
2113 let busy = db.is_user_busy(user_id).await?;
2114
2115 let pool = session.connection_pool().await;
2116 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2117 for contact in contacts {
2118 if let db::Contact::Accepted {
2119 user_id: contact_user_id,
2120 ..
2121 } = contact
2122 {
2123 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2124 session
2125 .peer
2126 .send(
2127 contact_conn_id,
2128 proto::UpdateContacts {
2129 contacts: vec![updated_contact.clone()],
2130 remove_contacts: Default::default(),
2131 incoming_requests: Default::default(),
2132 remove_incoming_requests: Default::default(),
2133 outgoing_requests: Default::default(),
2134 remove_outgoing_requests: Default::default(),
2135 },
2136 )
2137 .trace_err();
2138 }
2139 }
2140 }
2141 Ok(())
2142}
2143
2144async fn leave_room_for_session(session: &Session) -> Result<()> {
2145 let mut contacts_to_update = HashSet::default();
2146
2147 let room_id;
2148 let canceled_calls_to_user_ids;
2149 let live_kit_room;
2150 let delete_live_kit_room;
2151 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2152 contacts_to_update.insert(session.user_id);
2153
2154 for project in left_room.left_projects.values() {
2155 project_left(project, session);
2156 }
2157
2158 room_updated(&left_room.room, &session.peer);
2159 room_id = RoomId::from_proto(left_room.room.id);
2160 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2161 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2162 delete_live_kit_room = left_room.room.participants.is_empty();
2163 } else {
2164 return Ok(());
2165 }
2166
2167 {
2168 let pool = session.connection_pool().await;
2169 for canceled_user_id in canceled_calls_to_user_ids {
2170 for connection_id in pool.user_connection_ids(canceled_user_id) {
2171 session
2172 .peer
2173 .send(
2174 connection_id,
2175 proto::CallCanceled {
2176 room_id: room_id.to_proto(),
2177 },
2178 )
2179 .trace_err();
2180 }
2181 contacts_to_update.insert(canceled_user_id);
2182 }
2183 }
2184
2185 for contact_user_id in contacts_to_update {
2186 update_user_contacts(contact_user_id, &session).await?;
2187 }
2188
2189 if let Some(live_kit) = session.live_kit_client.as_ref() {
2190 live_kit
2191 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2192 .await
2193 .trace_err();
2194
2195 if delete_live_kit_room {
2196 live_kit.delete_room(live_kit_room).await.trace_err();
2197 }
2198 }
2199
2200 Ok(())
2201}
2202
2203fn project_left(project: &db::LeftProject, session: &Session) {
2204 for connection_id in &project.connection_ids {
2205 if project.host_user_id == session.user_id {
2206 session
2207 .peer
2208 .send(
2209 *connection_id,
2210 proto::UnshareProject {
2211 project_id: project.id.to_proto(),
2212 },
2213 )
2214 .trace_err();
2215 } else {
2216 session
2217 .peer
2218 .send(
2219 *connection_id,
2220 proto::RemoveProjectCollaborator {
2221 project_id: project.id.to_proto(),
2222 peer_id: Some(session.connection_id.into()),
2223 },
2224 )
2225 .trace_err();
2226 }
2227 }
2228}
2229
/// Extension for fallible values whose errors should be reported and then discarded rather than
/// propagated.
pub trait ResultExt {
    /// The value produced when there is no error.
    type Ok;

    /// Returns the success value, or `None` after handling the error (implementations in this
    /// file log it).
    fn trace_err(self) -> Option<Self::Ok>;
}
2235
2236impl<T, E> ResultExt for Result<T, E>
2237where
2238 E: std::fmt::Debug,
2239{
2240 type Ok = T;
2241
2242 fn trace_err(self) -> Option<T> {
2243 match self {
2244 Ok(value) => Some(value),
2245 Err(error) => {
2246 tracing::error!("{:?}", error);
2247 None
2248 }
2249 }
2250 }
2251}