1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, ServerId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::{Duration, Instant},
55};
56use tokio::sync::{watch, Semaphore};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `connection_lost` sleeps after a connection drops before it runs
/// its room and call cleanup for that session.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the task spawned by `Server::start` sleeps before it starts
/// looking for stale rooms.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
62
lazy_static! {
    // Gauge that `handle_metrics` sets to the number of non-admin connections.
    // Registration failure panics on first access because of the `unwrap`.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Gauge that `handle_metrics` sets from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
/// Type-erased handler stored in `Server::handlers`: it receives a boxed
/// envelope plus the sender's `Session` and returns a boxed future that
/// produces no value.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
/// Handle given to a request handler for answering one request of type `R`.
struct Response<R> {
    /// Peer through which the response is delivered.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Set by `send`; `add_request_handler` reads it to detect a handler that
    /// returned `Ok` without ever responding.
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection context handed to every message handler.
///
/// Cloning is cheap for the shared parts: `db`, `peer` and `connection_pool`
/// are all behind `Arc`.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// Identifier the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to this and other connections.
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, or `None` when the app state carries none.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor used for sleeping and for spawning detached work.
    executor: Executor,
}
100
101impl Session {
102 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 let guard = self.db.lock().await;
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 guard
109 }
110
111 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 let guard = self.connection_pool.lock();
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
/// RPC server state: the peer, the connection pool and the registered
/// message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and sends messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for timers and detached tasks.
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel that `teardown` signals; connection tasks subscribe to it.
    teardown: watch::Sender<()>,
}
150
/// Wrapper around the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the type `!Send`.
/// NOTE(review): presumably this is meant to stop the guard being held across
/// an `.await` in a `Send` future — confirm the intent.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
155
/// Serializable view of the server produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the `ConnectionPool` the
    // guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
162
163pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
164where
165 S: Serializer,
166 T: Deref<Target = U>,
167 U: Serialize,
168{
169 Serialize::serialize(value.deref(), serializer)
170}
171
172impl Server {
173 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
174 let mut server = Self {
175 id: parking_lot::Mutex::new(id),
176 peer: Peer::new(id.0 as u32),
177 app_state,
178 executor,
179 connection_pool: Default::default(),
180 handlers: Default::default(),
181 teardown: watch::channel(()).0,
182 };
183
184 server
185 .add_request_handler(ping)
186 .add_request_handler(create_room)
187 .add_request_handler(join_room)
188 .add_request_handler(rejoin_room)
189 .add_request_handler(leave_room)
190 .add_request_handler(call)
191 .add_request_handler(cancel_call)
192 .add_message_handler(decline_call)
193 .add_request_handler(update_participant_location)
194 .add_request_handler(share_project)
195 .add_message_handler(unshare_project)
196 .add_request_handler(join_project)
197 .add_message_handler(leave_project)
198 .add_request_handler(update_project)
199 .add_request_handler(update_worktree)
200 .add_message_handler(start_language_server)
201 .add_message_handler(update_language_server)
202 .add_message_handler(update_diagnostic_summary)
203 .add_request_handler(forward_project_request::<proto::GetHover>)
204 .add_request_handler(forward_project_request::<proto::GetDefinition>)
205 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetReferences>)
207 .add_request_handler(forward_project_request::<proto::SearchProject>)
208 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
209 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
210 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
213 .add_request_handler(forward_project_request::<proto::GetCompletions>)
214 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
215 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
217 .add_request_handler(forward_project_request::<proto::PrepareRename>)
218 .add_request_handler(forward_project_request::<proto::PerformRename>)
219 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
220 .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(forward_project_request::<proto::SaveBuffer>)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 pub async fn start(&self) -> Result<()> {
247 let server_id = *self.id.lock();
248 let app_state = self.app_state.clone();
249 let peer = self.peer.clone();
250 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
251 let pool = self.connection_pool.clone();
252 let live_kit_client = self.app_state.live_kit_client.clone();
253
254 let span = info_span!("start server");
255 self.executor.spawn_detached(
256 async move {
257 tracing::info!("waiting for cleanup timeout");
258 timeout.await;
259 tracing::info!("cleanup timeout expired, retrieving stale rooms");
260 if let Some(room_ids) = app_state
261 .db
262 .stale_room_ids(&app_state.config.zed_environment, server_id)
263 .await
264 .trace_err()
265 {
266 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
267 for room_id in room_ids {
268 let mut contacts_to_update = HashSet::default();
269 let mut canceled_calls_to_user_ids = Vec::new();
270 let mut live_kit_room = String::new();
271 let mut delete_live_kit_room = false;
272
273 if let Some(mut refreshed_room) = app_state
274 .db
275 .refresh_room(room_id, server_id)
276 .await
277 .trace_err()
278 {
279 tracing::info!(
280 room_id = room_id.0,
281 new_participant_count = refreshed_room.room.participants.len(),
282 "refreshed room"
283 );
284 room_updated(&refreshed_room.room, &peer);
285 contacts_to_update
286 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
287 contacts_to_update
288 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
289 canceled_calls_to_user_ids =
290 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
291 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
292 delete_live_kit_room = refreshed_room.room.participants.is_empty();
293 }
294
295 {
296 let pool = pool.lock();
297 for canceled_user_id in canceled_calls_to_user_ids {
298 for connection_id in pool.user_connection_ids(canceled_user_id) {
299 peer.send(
300 connection_id,
301 proto::CallCanceled {
302 room_id: room_id.to_proto(),
303 },
304 )
305 .trace_err();
306 }
307 }
308 }
309
310 for user_id in contacts_to_update {
311 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
312 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
313 if let Some((busy, contacts)) = busy.zip(contacts) {
314 let pool = pool.lock();
315 let updated_contact = contact_for_user(user_id, false, busy, &pool);
316 for contact in contacts {
317 if let db::Contact::Accepted {
318 user_id: contact_user_id,
319 ..
320 } = contact
321 {
322 for contact_conn_id in
323 pool.user_connection_ids(contact_user_id)
324 {
325 peer.send(
326 contact_conn_id,
327 proto::UpdateContacts {
328 contacts: vec![updated_contact.clone()],
329 remove_contacts: Default::default(),
330 incoming_requests: Default::default(),
331 remove_incoming_requests: Default::default(),
332 outgoing_requests: Default::default(),
333 remove_outgoing_requests: Default::default(),
334 },
335 )
336 .trace_err();
337 }
338 }
339 }
340 }
341 }
342
343 if let Some(live_kit) = live_kit_client.as_ref() {
344 if delete_live_kit_room {
345 live_kit.delete_room(live_kit_room).await.trace_err();
346 }
347 }
348 }
349 }
350
351 app_state
352 .db
353 .delete_stale_servers(&app_state.config.zed_environment, server_id)
354 .await
355 .trace_err();
356 }
357 .instrument(span),
358 );
359 Ok(())
360 }
361
362 pub fn teardown(&self) {
363 self.peer.teardown();
364 self.connection_pool.lock().reset();
365 let _ = self.teardown.send(());
366 }
367
/// Test-only: tears the server down, then stores the new id and resets the
/// peer with it.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    let peer_id = id.0 as u32;
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(peer_id);
}
374
/// Test-only: returns a copy of the current server id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
379
380 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
381 where
382 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
383 Fut: 'static + Send + Future<Output = Result<()>>,
384 M: EnvelopedMessage,
385 {
386 let prev_handler = self.handlers.insert(
387 TypeId::of::<M>(),
388 Box::new(move |envelope, session| {
389 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
390 let span = info_span!(
391 "handle message",
392 payload_type = envelope.payload_type_name()
393 );
394 span.in_scope(|| {
395 tracing::info!(
396 payload_type = envelope.payload_type_name(),
397 "message received"
398 );
399 });
400 let start_time = Instant::now();
401 let future = (handler)(*envelope, session);
402 async move {
403 let result = future.await;
404 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
405 match result {
406 Err(error) => {
407 tracing::error!(%error, ?duration_ms, "error handling message")
408 }
409 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
410 }
411 }
412 .instrument(span)
413 .boxed()
414 }),
415 );
416 if prev_handler.is_some() {
417 panic!("registered a handler for the same message twice");
418 }
419 self
420 }
421
422 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
423 where
424 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
425 Fut: 'static + Send + Future<Output = Result<()>>,
426 M: EnvelopedMessage,
427 {
428 self.add_handler(move |envelope, session| handler(envelope.payload, session));
429 self
430 }
431
432 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
433 where
434 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
435 Fut: Send + Future<Output = Result<()>>,
436 M: RequestMessage,
437 {
438 let handler = Arc::new(handler);
439 self.add_handler(move |envelope, session| {
440 let receipt = envelope.receipt();
441 let handler = handler.clone();
442 async move {
443 let peer = session.peer.clone();
444 let responded = Arc::new(AtomicBool::default());
445 let response = Response {
446 peer: peer.clone(),
447 responded: responded.clone(),
448 receipt,
449 };
450 match (handler)(envelope.payload, response, session).await {
451 Ok(()) => {
452 if responded.load(std::sync::atomic::Ordering::SeqCst) {
453 Ok(())
454 } else {
455 Err(anyhow!("handler did not send a response"))?
456 }
457 }
458 Err(error) => {
459 peer.respond_with_error(
460 receipt,
461 proto::Error {
462 message: error.to_string(),
463 },
464 )?;
465 Err(error)
466 }
467 }
468 }
469 })
470 }
471
472 pub fn handle_connection(
473 self: &Arc<Self>,
474 connection: Connection,
475 address: String,
476 user: User,
477 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
478 executor: Executor,
479 ) -> impl Future<Output = Result<()>> {
480 let this = self.clone();
481 let user_id = user.id;
482 let login = user.github_login;
483 let span = info_span!("handle connection", %user_id, %login, %address);
484 let mut teardown = self.teardown.subscribe();
485 async move {
486 let (connection_id, handle_io, mut incoming_rx) = this
487 .peer
488 .add_connection(connection, {
489 let executor = executor.clone();
490 move |duration| executor.sleep(duration)
491 });
492
493 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
494 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
495 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
496
497 if let Some(send_connection_id) = send_connection_id.take() {
498 let _ = send_connection_id.send(connection_id);
499 }
500
501 if !user.connected_once {
502 this.peer.send(connection_id, proto::ShowContacts {})?;
503 this.app_state.db.set_user_connected_once(user_id, true).await?;
504 }
505
506 let (contacts, invite_code) = future::try_join(
507 this.app_state.db.get_contacts(user_id),
508 this.app_state.db.get_invite_code_for_user(user_id)
509 ).await?;
510
511 {
512 let mut pool = this.connection_pool.lock();
513 pool.add_connection(connection_id, user_id, user.admin);
514 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
515
516 if let Some((code, count)) = invite_code {
517 this.peer.send(connection_id, proto::UpdateInviteInfo {
518 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
519 count: count as u32,
520 })?;
521 }
522 }
523
524 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
525 this.peer.send(connection_id, incoming_call)?;
526 }
527
528 let session = Session {
529 user_id,
530 connection_id,
531 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
532 peer: this.peer.clone(),
533 connection_pool: this.connection_pool.clone(),
534 live_kit_client: this.app_state.live_kit_client.clone(),
535 executor: executor.clone(),
536 };
537 update_user_contacts(user_id, &session).await?;
538
539 let handle_io = handle_io.fuse();
540 futures::pin_mut!(handle_io);
541
542 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
543 // This prevents deadlocks when e.g., client A performs a request to client B and
544 // client B performs a request to client A. If both clients stop processing further
545 // messages until their respective request completes, they won't have a chance to
546 // respond to the other client's request and cause a deadlock.
547 //
548 // This arrangement ensures we will attempt to process earlier messages first, but fall
549 // back to processing messages arrived later in the spirit of making progress.
550 let mut foreground_message_handlers = FuturesUnordered::new();
551 let concurrent_handlers = Arc::new(Semaphore::new(256));
552 loop {
553 let next_message = async {
554 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
555 let message = incoming_rx.next().await;
556 (permit, message)
557 }.fuse();
558 futures::pin_mut!(next_message);
559 futures::select_biased! {
560 _ = teardown.changed().fuse() => return Ok(()),
561 result = handle_io => {
562 if let Err(error) = result {
563 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
564 }
565 break;
566 }
567 _ = foreground_message_handlers.next() => {}
568 next_message = next_message => {
569 let (permit, message) = next_message;
570 if let Some(message) = message {
571 let type_name = message.payload_type_name();
572 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
573 let span_enter = span.enter();
574 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
575 let is_background = message.is_background();
576 let handle_message = (handler)(message, session.clone());
577 drop(span_enter);
578
579 let handle_message = async move {
580 handle_message.await;
581 drop(permit);
582 }.instrument(span);
583 if is_background {
584 executor.spawn_detached(handle_message);
585 } else {
586 foreground_message_handlers.push(handle_message);
587 }
588 } else {
589 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
590 }
591 } else {
592 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
593 break;
594 }
595 }
596 }
597 }
598
599 drop(foreground_message_handlers);
600 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
601 if let Err(error) = connection_lost(session, teardown, executor).await {
602 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
603 }
604
605 Ok(())
606 }.instrument(span)
607 }
608
609 pub async fn invite_code_redeemed(
610 self: &Arc<Self>,
611 inviter_id: UserId,
612 invitee_id: UserId,
613 ) -> Result<()> {
614 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
615 if let Some(code) = &user.invite_code {
616 let pool = self.connection_pool.lock();
617 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
618 for connection_id in pool.user_connection_ids(inviter_id) {
619 self.peer.send(
620 connection_id,
621 proto::UpdateContacts {
622 contacts: vec![invitee_contact.clone()],
623 ..Default::default()
624 },
625 )?;
626 self.peer.send(
627 connection_id,
628 proto::UpdateInviteInfo {
629 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
630 count: user.invite_count as u32,
631 },
632 )?;
633 }
634 }
635 }
636 Ok(())
637 }
638
639 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
640 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
641 if let Some(invite_code) = &user.invite_code {
642 let pool = self.connection_pool.lock();
643 for connection_id in pool.user_connection_ids(user_id) {
644 self.peer.send(
645 connection_id,
646 proto::UpdateInviteInfo {
647 url: format!(
648 "{}{}",
649 self.app_state.config.invite_link_prefix, invite_code
650 ),
651 count: user.invite_count as u32,
652 },
653 )?;
654 }
655 }
656 }
657 Ok(())
658 }
659
660 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
661 ServerSnapshot {
662 connection_pool: ConnectionPoolGuard {
663 guard: self.connection_pool.lock(),
664 _not_send: PhantomData,
665 },
666 peer: &self.peer,
667 }
668 }
669}
670
671impl<'a> Deref for ConnectionPoolGuard<'a> {
672 type Target = ConnectionPool;
673
674 fn deref(&self) -> &Self::Target {
675 &*self.guard
676 }
677}
678
679impl<'a> DerefMut for ConnectionPoolGuard<'a> {
680 fn deref_mut(&mut self) -> &mut Self::Target {
681 &mut *self.guard
682 }
683}
684
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants each time a guard is
    /// released; in other builds the body is empty.
    fn drop(&mut self) {
        // `check_invariants` is reached through `Deref` on the guard.
        #[cfg(test)]
        self.check_invariants();
    }
}
691
692fn broadcast<F>(
693 sender_id: Option<ConnectionId>,
694 receiver_ids: impl IntoIterator<Item = ConnectionId>,
695 mut f: F,
696) where
697 F: FnMut(ConnectionId) -> anyhow::Result<()>,
698{
699 for receiver_id in receiver_ids {
700 if Some(receiver_id) != sender_id {
701 if let Err(error) = f(receiver_id) {
702 tracing::error!("failed to send to {:?} {}", receiver_id, error);
703 }
704 }
705 }
706}
707
lazy_static! {
    // Name of the header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
711
/// Typed header wrapping the numeric value of `x-zed-protocol-version`.
pub struct ProtocolVersion(u32);
713
714impl Header for ProtocolVersion {
715 fn name() -> &'static HeaderName {
716 &ZED_PROTOCOL_VERSION
717 }
718
719 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
720 where
721 Self: Sized,
722 I: Iterator<Item = &'i axum::http::HeaderValue>,
723 {
724 let version = values
725 .next()
726 .ok_or_else(axum::headers::Error::invalid)?
727 .to_str()
728 .map_err(|_| axum::headers::Error::invalid())?
729 .parse()
730 .map_err(|_| axum::headers::Error::invalid())?;
731 Ok(Self(version))
732 }
733
734 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
735 values.extend([self.0.to_string().parse().unwrap()]);
736 }
737}
738
739pub fn routes(server: Arc<Server>) -> Router<Body> {
740 Router::new()
741 .route("/rpc", get(handle_websocket_request))
742 .layer(
743 ServiceBuilder::new()
744 .layer(Extension(server.app_state.clone()))
745 .layer(middleware::from_fn(auth::validate_header)),
746 )
747 .route("/metrics", get(handle_metrics))
748 .layer(Extension(server))
749}
750
751pub async fn handle_websocket_request(
752 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
753 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
754 Extension(server): Extension<Arc<Server>>,
755 Extension(user): Extension<User>,
756 ws: WebSocketUpgrade,
757) -> axum::response::Response {
758 if protocol_version != rpc::PROTOCOL_VERSION {
759 return (
760 StatusCode::UPGRADE_REQUIRED,
761 "client must be upgraded".to_string(),
762 )
763 .into_response();
764 }
765 let socket_address = socket_address.to_string();
766 ws.on_upgrade(move |socket| {
767 use util::ResultExt;
768 let socket = socket
769 .map_ok(to_tungstenite_message)
770 .err_into()
771 .with(|message| async move { Ok(to_axum_message(message)) });
772 let connection = Connection::new(Box::pin(socket));
773 async move {
774 server
775 .handle_connection(connection, socket_address, user, None, Executor::Production)
776 .await
777 .log_err();
778 }
779 })
780}
781
782pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
783 let connections = server
784 .connection_pool
785 .lock()
786 .connections()
787 .filter(|connection| !connection.admin)
788 .count();
789
790 METRIC_CONNECTIONS.set(connections as _);
791
792 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
793 METRIC_SHARED_PROJECTS.set(shared_projects as _);
794
795 let encoder = prometheus::TextEncoder::new();
796 let metric_families = prometheus::gather();
797 let encoded_metrics = encoder
798 .encode_to_string(&metric_families)
799 .map_err(|err| anyhow!("{}", err))?;
800 Ok(encoded_metrics)
801}
802
803#[instrument(err, skip(executor))]
804async fn connection_lost(
805 session: Session,
806 mut teardown: watch::Receiver<()>,
807 executor: Executor,
808) -> Result<()> {
809 session.peer.disconnect(session.connection_id);
810 session
811 .connection_pool()
812 .await
813 .remove_connection(session.connection_id)?;
814
815 session
816 .db()
817 .await
818 .connection_lost(session.connection_id)
819 .await
820 .trace_err();
821
822 futures::select_biased! {
823 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
824 leave_room_for_session(&session).await.trace_err();
825
826 if !session
827 .connection_pool()
828 .await
829 .is_user_online(session.user_id)
830 {
831 let db = session.db().await;
832 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
833 room_updated(&room, &session.peer);
834 }
835 }
836 update_user_contacts(session.user_id, &session).await?;
837 }
838 _ = teardown.changed().fuse() => {}
839 }
840
841 Ok(())
842}
843
844async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
845 response.send(proto::Ack {})?;
846 Ok(())
847}
848
849async fn create_room(
850 _request: proto::CreateRoom,
851 response: Response<proto::CreateRoom>,
852 session: Session,
853) -> Result<()> {
854 let live_kit_room = nanoid::nanoid!(30);
855 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
856 if let Some(_) = live_kit
857 .create_room(live_kit_room.clone())
858 .await
859 .trace_err()
860 {
861 if let Some(token) = live_kit
862 .room_token(&live_kit_room, &session.user_id.to_string())
863 .trace_err()
864 {
865 Some(proto::LiveKitConnectionInfo {
866 server_url: live_kit.url().into(),
867 token,
868 })
869 } else {
870 None
871 }
872 } else {
873 None
874 }
875 } else {
876 None
877 };
878
879 {
880 let room = session
881 .db()
882 .await
883 .create_room(session.user_id, session.connection_id, &live_kit_room)
884 .await?;
885
886 response.send(proto::CreateRoomResponse {
887 room: Some(room.clone()),
888 live_kit_connection_info,
889 })?;
890 }
891
892 update_user_contacts(session.user_id, &session).await?;
893 Ok(())
894}
895
896async fn join_room(
897 request: proto::JoinRoom,
898 response: Response<proto::JoinRoom>,
899 session: Session,
900) -> Result<()> {
901 let room_id = RoomId::from_proto(request.id);
902 let room = {
903 let room = session
904 .db()
905 .await
906 .join_room(room_id, session.user_id, session.connection_id)
907 .await?;
908 room_updated(&room, &session.peer);
909 room.clone()
910 };
911
912 for connection_id in session
913 .connection_pool()
914 .await
915 .user_connection_ids(session.user_id)
916 {
917 session
918 .peer
919 .send(
920 connection_id,
921 proto::CallCanceled {
922 room_id: room_id.to_proto(),
923 },
924 )
925 .trace_err();
926 }
927
928 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
929 if let Some(token) = live_kit
930 .room_token(&room.live_kit_room, &session.user_id.to_string())
931 .trace_err()
932 {
933 Some(proto::LiveKitConnectionInfo {
934 server_url: live_kit.url().into(),
935 token,
936 })
937 } else {
938 None
939 }
940 } else {
941 None
942 };
943
944 response.send(proto::JoinRoomResponse {
945 room: Some(room),
946 live_kit_connection_info,
947 })?;
948
949 update_user_contacts(session.user_id, &session).await?;
950 Ok(())
951}
952
953async fn rejoin_room(
954 request: proto::RejoinRoom,
955 response: Response<proto::RejoinRoom>,
956 session: Session,
957) -> Result<()> {
958 {
959 let mut rejoined_room = session
960 .db()
961 .await
962 .rejoin_room(request, session.user_id, session.connection_id)
963 .await?;
964
965 response.send(proto::RejoinRoomResponse {
966 room: Some(rejoined_room.room.clone()),
967 reshared_projects: rejoined_room
968 .reshared_projects
969 .iter()
970 .map(|project| proto::ResharedProject {
971 id: project.id.to_proto(),
972 collaborators: project
973 .collaborators
974 .iter()
975 .map(|collaborator| collaborator.to_proto())
976 .collect(),
977 })
978 .collect(),
979 rejoined_projects: rejoined_room
980 .rejoined_projects
981 .iter()
982 .map(|rejoined_project| proto::RejoinedProject {
983 id: rejoined_project.id.to_proto(),
984 worktrees: rejoined_project
985 .worktrees
986 .iter()
987 .map(|worktree| proto::WorktreeMetadata {
988 id: worktree.id,
989 root_name: worktree.root_name.clone(),
990 visible: worktree.visible,
991 abs_path: worktree.abs_path.clone(),
992 })
993 .collect(),
994 collaborators: rejoined_project
995 .collaborators
996 .iter()
997 .map(|collaborator| collaborator.to_proto())
998 .collect(),
999 language_servers: rejoined_project.language_servers.clone(),
1000 })
1001 .collect(),
1002 })?;
1003 room_updated(&rejoined_room.room, &session.peer);
1004
1005 for project in &rejoined_room.reshared_projects {
1006 for collaborator in &project.collaborators {
1007 session
1008 .peer
1009 .send(
1010 collaborator.connection_id,
1011 proto::UpdateProjectCollaborator {
1012 project_id: project.id.to_proto(),
1013 old_peer_id: Some(project.old_connection_id.into()),
1014 new_peer_id: Some(session.connection_id.into()),
1015 },
1016 )
1017 .trace_err();
1018 }
1019
1020 broadcast(
1021 Some(session.connection_id),
1022 project
1023 .collaborators
1024 .iter()
1025 .map(|collaborator| collaborator.connection_id),
1026 |connection_id| {
1027 session.peer.forward_send(
1028 session.connection_id,
1029 connection_id,
1030 proto::UpdateProject {
1031 project_id: project.id.to_proto(),
1032 worktrees: project.worktrees.clone(),
1033 },
1034 )
1035 },
1036 );
1037 }
1038
1039 for project in &rejoined_room.rejoined_projects {
1040 for collaborator in &project.collaborators {
1041 session
1042 .peer
1043 .send(
1044 collaborator.connection_id,
1045 proto::UpdateProjectCollaborator {
1046 project_id: project.id.to_proto(),
1047 old_peer_id: Some(project.old_connection_id.into()),
1048 new_peer_id: Some(session.connection_id.into()),
1049 },
1050 )
1051 .trace_err();
1052 }
1053 }
1054
1055 for project in &mut rejoined_room.rejoined_projects {
1056 for worktree in mem::take(&mut project.worktrees) {
1057 #[cfg(any(test, feature = "test-support"))]
1058 const MAX_CHUNK_SIZE: usize = 2;
1059 #[cfg(not(any(test, feature = "test-support")))]
1060 const MAX_CHUNK_SIZE: usize = 256;
1061
1062 // Stream this worktree's entries.
1063 let message = proto::UpdateWorktree {
1064 project_id: project.id.to_proto(),
1065 worktree_id: worktree.id,
1066 abs_path: worktree.abs_path.clone(),
1067 root_name: worktree.root_name,
1068 updated_entries: worktree.updated_entries,
1069 removed_entries: worktree.removed_entries,
1070 scan_id: worktree.scan_id,
1071 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1072 updated_repositories: worktree.updated_repositories,
1073 removed_repositories: worktree.removed_repositories,
1074 };
1075 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1076 session.peer.send(session.connection_id, update.clone())?;
1077 }
1078
1079 // Stream this worktree's diagnostics.
1080 for summary in worktree.diagnostic_summaries {
1081 session.peer.send(
1082 session.connection_id,
1083 proto::UpdateDiagnosticSummary {
1084 project_id: project.id.to_proto(),
1085 worktree_id: worktree.id,
1086 summary: Some(summary),
1087 },
1088 )?;
1089 }
1090 }
1091
1092 for language_server in &project.language_servers {
1093 session.peer.send(
1094 session.connection_id,
1095 proto::UpdateLanguageServer {
1096 project_id: project.id.to_proto(),
1097 language_server_id: language_server.id,
1098 variant: Some(
1099 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1100 proto::LspDiskBasedDiagnosticsUpdated {},
1101 ),
1102 ),
1103 },
1104 )?;
1105 }
1106 }
1107 }
1108
1109 update_user_contacts(session.user_id, &session).await?;
1110 Ok(())
1111}
1112
1113async fn leave_room(
1114 _: proto::LeaveRoom,
1115 response: Response<proto::LeaveRoom>,
1116 session: Session,
1117) -> Result<()> {
1118 leave_room_for_session(&session).await?;
1119 response.send(proto::Ack {})?;
1120 Ok(())
1121}
1122
1123async fn call(
1124 request: proto::Call,
1125 response: Response<proto::Call>,
1126 session: Session,
1127) -> Result<()> {
1128 let room_id = RoomId::from_proto(request.room_id);
1129 let calling_user_id = session.user_id;
1130 let calling_connection_id = session.connection_id;
1131 let called_user_id = UserId::from_proto(request.called_user_id);
1132 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1133 if !session
1134 .db()
1135 .await
1136 .has_contact(calling_user_id, called_user_id)
1137 .await?
1138 {
1139 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1140 }
1141
1142 let incoming_call = {
1143 let (room, incoming_call) = &mut *session
1144 .db()
1145 .await
1146 .call(
1147 room_id,
1148 calling_user_id,
1149 calling_connection_id,
1150 called_user_id,
1151 initial_project_id,
1152 )
1153 .await?;
1154 room_updated(&room, &session.peer);
1155 mem::take(incoming_call)
1156 };
1157 update_user_contacts(called_user_id, &session).await?;
1158
1159 let mut calls = session
1160 .connection_pool()
1161 .await
1162 .user_connection_ids(called_user_id)
1163 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1164 .collect::<FuturesUnordered<_>>();
1165
1166 while let Some(call_response) = calls.next().await {
1167 match call_response.as_ref() {
1168 Ok(_) => {
1169 response.send(proto::Ack {})?;
1170 return Ok(());
1171 }
1172 Err(_) => {
1173 call_response.trace_err();
1174 }
1175 }
1176 }
1177
1178 {
1179 let room = session
1180 .db()
1181 .await
1182 .call_failed(room_id, called_user_id)
1183 .await?;
1184 room_updated(&room, &session.peer);
1185 }
1186 update_user_contacts(called_user_id, &session).await?;
1187
1188 Err(anyhow!("failed to ring user"))?
1189}
1190
1191async fn cancel_call(
1192 request: proto::CancelCall,
1193 response: Response<proto::CancelCall>,
1194 session: Session,
1195) -> Result<()> {
1196 let called_user_id = UserId::from_proto(request.called_user_id);
1197 let room_id = RoomId::from_proto(request.room_id);
1198 {
1199 let room = session
1200 .db()
1201 .await
1202 .cancel_call(room_id, session.connection_id, called_user_id)
1203 .await?;
1204 room_updated(&room, &session.peer);
1205 }
1206
1207 for connection_id in session
1208 .connection_pool()
1209 .await
1210 .user_connection_ids(called_user_id)
1211 {
1212 session
1213 .peer
1214 .send(
1215 connection_id,
1216 proto::CallCanceled {
1217 room_id: room_id.to_proto(),
1218 },
1219 )
1220 .trace_err();
1221 }
1222 response.send(proto::Ack {})?;
1223
1224 update_user_contacts(called_user_id, &session).await?;
1225 Ok(())
1226}
1227
1228async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1229 let room_id = RoomId::from_proto(message.room_id);
1230 {
1231 let room = session
1232 .db()
1233 .await
1234 .decline_call(Some(room_id), session.user_id)
1235 .await?
1236 .ok_or_else(|| anyhow!("failed to decline call"))?;
1237 room_updated(&room, &session.peer);
1238 }
1239
1240 for connection_id in session
1241 .connection_pool()
1242 .await
1243 .user_connection_ids(session.user_id)
1244 {
1245 session
1246 .peer
1247 .send(
1248 connection_id,
1249 proto::CallCanceled {
1250 room_id: room_id.to_proto(),
1251 },
1252 )
1253 .trace_err();
1254 }
1255 update_user_contacts(session.user_id, &session).await?;
1256 Ok(())
1257}
1258
1259async fn update_participant_location(
1260 request: proto::UpdateParticipantLocation,
1261 response: Response<proto::UpdateParticipantLocation>,
1262 session: Session,
1263) -> Result<()> {
1264 let room_id = RoomId::from_proto(request.room_id);
1265 let location = request
1266 .location
1267 .ok_or_else(|| anyhow!("invalid location"))?;
1268 let room = session
1269 .db()
1270 .await
1271 .update_room_participant_location(room_id, session.connection_id, location)
1272 .await?;
1273 room_updated(&room, &session.peer);
1274 response.send(proto::Ack {})?;
1275 Ok(())
1276}
1277
1278async fn share_project(
1279 request: proto::ShareProject,
1280 response: Response<proto::ShareProject>,
1281 session: Session,
1282) -> Result<()> {
1283 let (project_id, room) = &*session
1284 .db()
1285 .await
1286 .share_project(
1287 RoomId::from_proto(request.room_id),
1288 session.connection_id,
1289 &request.worktrees,
1290 )
1291 .await?;
1292 response.send(proto::ShareProjectResponse {
1293 project_id: project_id.to_proto(),
1294 })?;
1295 room_updated(&room, &session.peer);
1296
1297 Ok(())
1298}
1299
1300async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1301 let project_id = ProjectId::from_proto(message.project_id);
1302
1303 let (room, guest_connection_ids) = &*session
1304 .db()
1305 .await
1306 .unshare_project(project_id, session.connection_id)
1307 .await?;
1308
1309 broadcast(
1310 Some(session.connection_id),
1311 guest_connection_ids.iter().copied(),
1312 |conn_id| session.peer.send(conn_id, message.clone()),
1313 );
1314 room_updated(&room, &session.peer);
1315
1316 Ok(())
1317}
1318
1319async fn join_project(
1320 request: proto::JoinProject,
1321 response: Response<proto::JoinProject>,
1322 session: Session,
1323) -> Result<()> {
1324 let project_id = ProjectId::from_proto(request.project_id);
1325 let guest_user_id = session.user_id;
1326
1327 tracing::info!(%project_id, "join project");
1328
1329 let (project, replica_id) = &mut *session
1330 .db()
1331 .await
1332 .join_project(project_id, session.connection_id)
1333 .await?;
1334
1335 let collaborators = project
1336 .collaborators
1337 .iter()
1338 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1339 .map(|collaborator| collaborator.to_proto())
1340 .collect::<Vec<_>>();
1341
1342 let worktrees = project
1343 .worktrees
1344 .iter()
1345 .map(|(id, worktree)| proto::WorktreeMetadata {
1346 id: *id,
1347 root_name: worktree.root_name.clone(),
1348 visible: worktree.visible,
1349 abs_path: worktree.abs_path.clone(),
1350 })
1351 .collect::<Vec<_>>();
1352
1353 for collaborator in &collaborators {
1354 session
1355 .peer
1356 .send(
1357 collaborator.peer_id.unwrap().into(),
1358 proto::AddProjectCollaborator {
1359 project_id: project_id.to_proto(),
1360 collaborator: Some(proto::Collaborator {
1361 peer_id: Some(session.connection_id.into()),
1362 replica_id: replica_id.0 as u32,
1363 user_id: guest_user_id.to_proto(),
1364 }),
1365 },
1366 )
1367 .trace_err();
1368 }
1369
1370 // First, we send the metadata associated with each worktree.
1371 response.send(proto::JoinProjectResponse {
1372 worktrees: worktrees.clone(),
1373 replica_id: replica_id.0 as u32,
1374 collaborators: collaborators.clone(),
1375 language_servers: project.language_servers.clone(),
1376 })?;
1377
1378 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1379 #[cfg(any(test, feature = "test-support"))]
1380 const MAX_CHUNK_SIZE: usize = 2;
1381 #[cfg(not(any(test, feature = "test-support")))]
1382 const MAX_CHUNK_SIZE: usize = 256;
1383
1384 // Stream this worktree's entries.
1385 let message = proto::UpdateWorktree {
1386 project_id: project_id.to_proto(),
1387 worktree_id,
1388 abs_path: worktree.abs_path.clone(),
1389 root_name: worktree.root_name,
1390 updated_entries: worktree.entries,
1391 removed_entries: Default::default(),
1392 scan_id: worktree.scan_id,
1393 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1394 updated_repositories: worktree.repository_entries.into_values().collect(),
1395 removed_repositories: Default::default(),
1396 };
1397 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1398 session.peer.send(session.connection_id, update.clone())?;
1399 }
1400
1401 // Stream this worktree's diagnostics.
1402 for summary in worktree.diagnostic_summaries {
1403 session.peer.send(
1404 session.connection_id,
1405 proto::UpdateDiagnosticSummary {
1406 project_id: project_id.to_proto(),
1407 worktree_id: worktree.id,
1408 summary: Some(summary),
1409 },
1410 )?;
1411 }
1412 }
1413
1414 for language_server in &project.language_servers {
1415 session.peer.send(
1416 session.connection_id,
1417 proto::UpdateLanguageServer {
1418 project_id: project_id.to_proto(),
1419 language_server_id: language_server.id,
1420 variant: Some(
1421 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1422 proto::LspDiskBasedDiagnosticsUpdated {},
1423 ),
1424 ),
1425 },
1426 )?;
1427 }
1428
1429 Ok(())
1430}
1431
1432async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1433 let sender_id = session.connection_id;
1434 let project_id = ProjectId::from_proto(request.project_id);
1435
1436 let (room, project) = &*session
1437 .db()
1438 .await
1439 .leave_project(project_id, sender_id)
1440 .await?;
1441 tracing::info!(
1442 %project_id,
1443 host_user_id = %project.host_user_id,
1444 host_connection_id = %project.host_connection_id,
1445 "leave project"
1446 );
1447
1448 project_left(&project, &session);
1449 room_updated(&room, &session.peer);
1450
1451 Ok(())
1452}
1453
1454async fn update_project(
1455 request: proto::UpdateProject,
1456 response: Response<proto::UpdateProject>,
1457 session: Session,
1458) -> Result<()> {
1459 let project_id = ProjectId::from_proto(request.project_id);
1460 let (room, guest_connection_ids) = &*session
1461 .db()
1462 .await
1463 .update_project(project_id, session.connection_id, &request.worktrees)
1464 .await?;
1465 broadcast(
1466 Some(session.connection_id),
1467 guest_connection_ids.iter().copied(),
1468 |connection_id| {
1469 session
1470 .peer
1471 .forward_send(session.connection_id, connection_id, request.clone())
1472 },
1473 );
1474 room_updated(&room, &session.peer);
1475 response.send(proto::Ack {})?;
1476
1477 Ok(())
1478}
1479
1480async fn update_worktree(
1481 request: proto::UpdateWorktree,
1482 response: Response<proto::UpdateWorktree>,
1483 session: Session,
1484) -> Result<()> {
1485 let guest_connection_ids = session
1486 .db()
1487 .await
1488 .update_worktree(&request, session.connection_id)
1489 .await?;
1490
1491 broadcast(
1492 Some(session.connection_id),
1493 guest_connection_ids.iter().copied(),
1494 |connection_id| {
1495 session
1496 .peer
1497 .forward_send(session.connection_id, connection_id, request.clone())
1498 },
1499 );
1500 response.send(proto::Ack {})?;
1501 Ok(())
1502}
1503
1504async fn update_diagnostic_summary(
1505 message: proto::UpdateDiagnosticSummary,
1506 session: Session,
1507) -> Result<()> {
1508 let guest_connection_ids = session
1509 .db()
1510 .await
1511 .update_diagnostic_summary(&message, session.connection_id)
1512 .await?;
1513
1514 broadcast(
1515 Some(session.connection_id),
1516 guest_connection_ids.iter().copied(),
1517 |connection_id| {
1518 session
1519 .peer
1520 .forward_send(session.connection_id, connection_id, message.clone())
1521 },
1522 );
1523
1524 Ok(())
1525}
1526
1527async fn start_language_server(
1528 request: proto::StartLanguageServer,
1529 session: Session,
1530) -> Result<()> {
1531 let guest_connection_ids = session
1532 .db()
1533 .await
1534 .start_language_server(&request, session.connection_id)
1535 .await?;
1536
1537 broadcast(
1538 Some(session.connection_id),
1539 guest_connection_ids.iter().copied(),
1540 |connection_id| {
1541 session
1542 .peer
1543 .forward_send(session.connection_id, connection_id, request.clone())
1544 },
1545 );
1546 Ok(())
1547}
1548
1549async fn update_language_server(
1550 request: proto::UpdateLanguageServer,
1551 session: Session,
1552) -> Result<()> {
1553 session.executor.record_backtrace();
1554 let project_id = ProjectId::from_proto(request.project_id);
1555 let project_connection_ids = session
1556 .db()
1557 .await
1558 .project_connection_ids(project_id, session.connection_id)
1559 .await?;
1560 broadcast(
1561 Some(session.connection_id),
1562 project_connection_ids.iter().copied(),
1563 |connection_id| {
1564 session
1565 .peer
1566 .forward_send(session.connection_id, connection_id, request.clone())
1567 },
1568 );
1569 Ok(())
1570}
1571
1572async fn forward_project_request<T>(
1573 request: T,
1574 response: Response<T>,
1575 session: Session,
1576) -> Result<()>
1577where
1578 T: EntityMessage + RequestMessage,
1579{
1580 session.executor.record_backtrace();
1581 let project_id = ProjectId::from_proto(request.remote_entity_id());
1582 let host_connection_id = {
1583 let collaborators = session
1584 .db()
1585 .await
1586 .project_collaborators(project_id, session.connection_id)
1587 .await?;
1588 collaborators
1589 .iter()
1590 .find(|collaborator| collaborator.is_host)
1591 .ok_or_else(|| anyhow!("host not found"))?
1592 .connection_id
1593 };
1594
1595 let payload = session
1596 .peer
1597 .forward_request(session.connection_id, host_connection_id, request)
1598 .await?;
1599
1600 response.send(payload)?;
1601 Ok(())
1602}
1603
1604async fn create_buffer_for_peer(
1605 request: proto::CreateBufferForPeer,
1606 session: Session,
1607) -> Result<()> {
1608 session.executor.record_backtrace();
1609 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1610 session
1611 .peer
1612 .forward_send(session.connection_id, peer_id.into(), request)?;
1613 Ok(())
1614}
1615
1616async fn update_buffer(
1617 request: proto::UpdateBuffer,
1618 response: Response<proto::UpdateBuffer>,
1619 session: Session,
1620) -> Result<()> {
1621 session.executor.record_backtrace();
1622 let project_id = ProjectId::from_proto(request.project_id);
1623 let mut guest_connection_ids;
1624 let mut host_connection_id = None;
1625 {
1626 let collaborators = session
1627 .db()
1628 .await
1629 .project_collaborators(project_id, session.connection_id)
1630 .await?;
1631 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1632 for collaborator in collaborators.iter() {
1633 if collaborator.is_host {
1634 host_connection_id = Some(collaborator.connection_id);
1635 } else {
1636 guest_connection_ids.push(collaborator.connection_id);
1637 }
1638 }
1639 }
1640 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1641
1642 session.executor.record_backtrace();
1643 broadcast(
1644 Some(session.connection_id),
1645 guest_connection_ids,
1646 |connection_id| {
1647 session
1648 .peer
1649 .forward_send(session.connection_id, connection_id, request.clone())
1650 },
1651 );
1652 if host_connection_id != session.connection_id {
1653 session
1654 .peer
1655 .forward_request(session.connection_id, host_connection_id, request.clone())
1656 .await?;
1657 }
1658
1659 response.send(proto::Ack {})?;
1660 Ok(())
1661}
1662
1663async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1664 let project_id = ProjectId::from_proto(request.project_id);
1665 let project_connection_ids = session
1666 .db()
1667 .await
1668 .project_connection_ids(project_id, session.connection_id)
1669 .await?;
1670
1671 broadcast(
1672 Some(session.connection_id),
1673 project_connection_ids.iter().copied(),
1674 |connection_id| {
1675 session
1676 .peer
1677 .forward_send(session.connection_id, connection_id, request.clone())
1678 },
1679 );
1680 Ok(())
1681}
1682
1683async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1684 let project_id = ProjectId::from_proto(request.project_id);
1685 let project_connection_ids = session
1686 .db()
1687 .await
1688 .project_connection_ids(project_id, session.connection_id)
1689 .await?;
1690 broadcast(
1691 Some(session.connection_id),
1692 project_connection_ids.iter().copied(),
1693 |connection_id| {
1694 session
1695 .peer
1696 .forward_send(session.connection_id, connection_id, request.clone())
1697 },
1698 );
1699 Ok(())
1700}
1701
1702async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1703 let project_id = ProjectId::from_proto(request.project_id);
1704 let project_connection_ids = session
1705 .db()
1706 .await
1707 .project_connection_ids(project_id, session.connection_id)
1708 .await?;
1709 broadcast(
1710 Some(session.connection_id),
1711 project_connection_ids.iter().copied(),
1712 |connection_id| {
1713 session
1714 .peer
1715 .forward_send(session.connection_id, connection_id, request.clone())
1716 },
1717 );
1718 Ok(())
1719}
1720
1721async fn follow(
1722 request: proto::Follow,
1723 response: Response<proto::Follow>,
1724 session: Session,
1725) -> Result<()> {
1726 let project_id = ProjectId::from_proto(request.project_id);
1727 let leader_id = request
1728 .leader_id
1729 .ok_or_else(|| anyhow!("invalid leader id"))?
1730 .into();
1731 let follower_id = session.connection_id;
1732
1733 {
1734 let project_connection_ids = session
1735 .db()
1736 .await
1737 .project_connection_ids(project_id, session.connection_id)
1738 .await?;
1739
1740 if !project_connection_ids.contains(&leader_id) {
1741 Err(anyhow!("no such peer"))?;
1742 }
1743 }
1744
1745 let mut response_payload = session
1746 .peer
1747 .forward_request(session.connection_id, leader_id, request)
1748 .await?;
1749 response_payload
1750 .views
1751 .retain(|view| view.leader_id != Some(follower_id.into()));
1752 response.send(response_payload)?;
1753
1754 let room = session
1755 .db()
1756 .await
1757 .follow(project_id, leader_id, follower_id)
1758 .await?;
1759 room_updated(&room, &session.peer);
1760
1761 Ok(())
1762}
1763
1764async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1765 let project_id = ProjectId::from_proto(request.project_id);
1766 let leader_id = request
1767 .leader_id
1768 .ok_or_else(|| anyhow!("invalid leader id"))?
1769 .into();
1770 let follower_id = session.connection_id;
1771
1772 if !session
1773 .db()
1774 .await
1775 .project_connection_ids(project_id, session.connection_id)
1776 .await?
1777 .contains(&leader_id)
1778 {
1779 Err(anyhow!("no such peer"))?;
1780 }
1781
1782 session
1783 .peer
1784 .forward_send(session.connection_id, leader_id, request)?;
1785
1786 let room = session
1787 .db()
1788 .await
1789 .unfollow(project_id, leader_id, follower_id)
1790 .await?;
1791 room_updated(&room, &session.peer);
1792
1793 Ok(())
1794}
1795
1796async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1797 let project_id = ProjectId::from_proto(request.project_id);
1798 let project_connection_ids = session
1799 .db
1800 .lock()
1801 .await
1802 .project_connection_ids(project_id, session.connection_id)
1803 .await?;
1804
1805 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1806 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1807 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1808 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1809 });
1810 for follower_peer_id in request.follower_ids.iter().copied() {
1811 let follower_connection_id = follower_peer_id.into();
1812 if project_connection_ids.contains(&follower_connection_id)
1813 && Some(follower_peer_id) != leader_id
1814 {
1815 session.peer.forward_send(
1816 session.connection_id,
1817 follower_connection_id,
1818 request.clone(),
1819 )?;
1820 }
1821 }
1822 Ok(())
1823}
1824
1825async fn get_users(
1826 request: proto::GetUsers,
1827 response: Response<proto::GetUsers>,
1828 session: Session,
1829) -> Result<()> {
1830 let user_ids = request
1831 .user_ids
1832 .into_iter()
1833 .map(UserId::from_proto)
1834 .collect();
1835 let users = session
1836 .db()
1837 .await
1838 .get_users_by_ids(user_ids)
1839 .await?
1840 .into_iter()
1841 .map(|user| proto::User {
1842 id: user.id.to_proto(),
1843 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1844 github_login: user.github_login,
1845 })
1846 .collect();
1847 response.send(proto::UsersResponse { users })?;
1848 Ok(())
1849}
1850
1851async fn fuzzy_search_users(
1852 request: proto::FuzzySearchUsers,
1853 response: Response<proto::FuzzySearchUsers>,
1854 session: Session,
1855) -> Result<()> {
1856 let query = request.query;
1857 let users = match query.len() {
1858 0 => vec![],
1859 1 | 2 => session
1860 .db()
1861 .await
1862 .get_user_by_github_login(&query)
1863 .await?
1864 .into_iter()
1865 .collect(),
1866 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1867 };
1868 let users = users
1869 .into_iter()
1870 .filter(|user| user.id != session.user_id)
1871 .map(|user| proto::User {
1872 id: user.id.to_proto(),
1873 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1874 github_login: user.github_login,
1875 })
1876 .collect();
1877 response.send(proto::UsersResponse { users })?;
1878 Ok(())
1879}
1880
1881async fn request_contact(
1882 request: proto::RequestContact,
1883 response: Response<proto::RequestContact>,
1884 session: Session,
1885) -> Result<()> {
1886 let requester_id = session.user_id;
1887 let responder_id = UserId::from_proto(request.responder_id);
1888 if requester_id == responder_id {
1889 return Err(anyhow!("cannot add yourself as a contact"))?;
1890 }
1891
1892 session
1893 .db()
1894 .await
1895 .send_contact_request(requester_id, responder_id)
1896 .await?;
1897
1898 // Update outgoing contact requests of requester
1899 let mut update = proto::UpdateContacts::default();
1900 update.outgoing_requests.push(responder_id.to_proto());
1901 for connection_id in session
1902 .connection_pool()
1903 .await
1904 .user_connection_ids(requester_id)
1905 {
1906 session.peer.send(connection_id, update.clone())?;
1907 }
1908
1909 // Update incoming contact requests of responder
1910 let mut update = proto::UpdateContacts::default();
1911 update
1912 .incoming_requests
1913 .push(proto::IncomingContactRequest {
1914 requester_id: requester_id.to_proto(),
1915 should_notify: true,
1916 });
1917 for connection_id in session
1918 .connection_pool()
1919 .await
1920 .user_connection_ids(responder_id)
1921 {
1922 session.peer.send(connection_id, update.clone())?;
1923 }
1924
1925 response.send(proto::Ack {})?;
1926 Ok(())
1927}
1928
1929async fn respond_to_contact_request(
1930 request: proto::RespondToContactRequest,
1931 response: Response<proto::RespondToContactRequest>,
1932 session: Session,
1933) -> Result<()> {
1934 let responder_id = session.user_id;
1935 let requester_id = UserId::from_proto(request.requester_id);
1936 let db = session.db().await;
1937 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1938 db.dismiss_contact_notification(responder_id, requester_id)
1939 .await?;
1940 } else {
1941 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1942
1943 db.respond_to_contact_request(responder_id, requester_id, accept)
1944 .await?;
1945 let requester_busy = db.is_user_busy(requester_id).await?;
1946 let responder_busy = db.is_user_busy(responder_id).await?;
1947
1948 let pool = session.connection_pool().await;
1949 // Update responder with new contact
1950 let mut update = proto::UpdateContacts::default();
1951 if accept {
1952 update
1953 .contacts
1954 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1955 }
1956 update
1957 .remove_incoming_requests
1958 .push(requester_id.to_proto());
1959 for connection_id in pool.user_connection_ids(responder_id) {
1960 session.peer.send(connection_id, update.clone())?;
1961 }
1962
1963 // Update requester with new contact
1964 let mut update = proto::UpdateContacts::default();
1965 if accept {
1966 update
1967 .contacts
1968 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1969 }
1970 update
1971 .remove_outgoing_requests
1972 .push(responder_id.to_proto());
1973 for connection_id in pool.user_connection_ids(requester_id) {
1974 session.peer.send(connection_id, update.clone())?;
1975 }
1976 }
1977
1978 response.send(proto::Ack {})?;
1979 Ok(())
1980}
1981
1982async fn remove_contact(
1983 request: proto::RemoveContact,
1984 response: Response<proto::RemoveContact>,
1985 session: Session,
1986) -> Result<()> {
1987 let requester_id = session.user_id;
1988 let responder_id = UserId::from_proto(request.user_id);
1989 let db = session.db().await;
1990 let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
1991
1992 let pool = session.connection_pool().await;
1993 // Update outgoing contact requests of requester
1994 let mut update = proto::UpdateContacts::default();
1995 if contact_accepted {
1996 update.remove_contacts.push(responder_id.to_proto());
1997 } else {
1998 update
1999 .remove_outgoing_requests
2000 .push(responder_id.to_proto());
2001 }
2002 for connection_id in pool.user_connection_ids(requester_id) {
2003 session.peer.send(connection_id, update.clone())?;
2004 }
2005
2006 // Update incoming contact requests of responder
2007 let mut update = proto::UpdateContacts::default();
2008 if contact_accepted {
2009 update.remove_contacts.push(requester_id.to_proto());
2010 } else {
2011 update
2012 .remove_incoming_requests
2013 .push(requester_id.to_proto());
2014 }
2015 for connection_id in pool.user_connection_ids(responder_id) {
2016 session.peer.send(connection_id, update.clone())?;
2017 }
2018
2019 response.send(proto::Ack {})?;
2020 Ok(())
2021}
2022
2023async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
2024 let project_id = ProjectId::from_proto(request.project_id);
2025 let project_connection_ids = session
2026 .db()
2027 .await
2028 .project_connection_ids(project_id, session.connection_id)
2029 .await?;
2030 broadcast(
2031 Some(session.connection_id),
2032 project_connection_ids.iter().copied(),
2033 |connection_id| {
2034 session
2035 .peer
2036 .forward_send(session.connection_id, connection_id, request.clone())
2037 },
2038 );
2039 Ok(())
2040}
2041
2042async fn get_private_user_info(
2043 _request: proto::GetPrivateUserInfo,
2044 response: Response<proto::GetPrivateUserInfo>,
2045 session: Session,
2046) -> Result<()> {
2047 let metrics_id = session
2048 .db()
2049 .await
2050 .get_user_metrics_id(session.user_id)
2051 .await?;
2052 let user = session
2053 .db()
2054 .await
2055 .get_user_by_id(session.user_id)
2056 .await?
2057 .ok_or_else(|| anyhow!("user not found"))?;
2058 response.send(proto::GetPrivateUserInfoResponse {
2059 metrics_id,
2060 staff: user.admin,
2061 })?;
2062 Ok(())
2063}
2064
2065fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
2066 match message {
2067 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
2068 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
2069 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
2070 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
2071 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
2072 code: frame.code.into(),
2073 reason: frame.reason,
2074 })),
2075 }
2076}
2077
2078fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
2079 match message {
2080 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
2081 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
2082 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
2083 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
2084 AxumMessage::Close(frame) => {
2085 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
2086 code: frame.code.into(),
2087 reason: frame.reason,
2088 }))
2089 }
2090 }
2091}
2092
2093fn build_initial_contacts_update(
2094 contacts: Vec<db::Contact>,
2095 pool: &ConnectionPool,
2096) -> proto::UpdateContacts {
2097 let mut update = proto::UpdateContacts::default();
2098
2099 for contact in contacts {
2100 match contact {
2101 db::Contact::Accepted {
2102 user_id,
2103 should_notify,
2104 busy,
2105 } => {
2106 update
2107 .contacts
2108 .push(contact_for_user(user_id, should_notify, busy, &pool));
2109 }
2110 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
2111 db::Contact::Incoming {
2112 user_id,
2113 should_notify,
2114 } => update
2115 .incoming_requests
2116 .push(proto::IncomingContactRequest {
2117 requester_id: user_id.to_proto(),
2118 should_notify,
2119 }),
2120 }
2121 }
2122
2123 update
2124}
2125
2126fn contact_for_user(
2127 user_id: UserId,
2128 should_notify: bool,
2129 busy: bool,
2130 pool: &ConnectionPool,
2131) -> proto::Contact {
2132 proto::Contact {
2133 user_id: user_id.to_proto(),
2134 online: pool.is_user_online(user_id),
2135 busy,
2136 should_notify,
2137 }
2138}
2139
2140fn room_updated(room: &proto::Room, peer: &Peer) {
2141 broadcast(
2142 None,
2143 room.participants
2144 .iter()
2145 .filter_map(|participant| Some(participant.peer_id?.into())),
2146 |peer_id| {
2147 peer.send(
2148 peer_id.into(),
2149 proto::RoomUpdated {
2150 room: Some(room.clone()),
2151 },
2152 )
2153 },
2154 );
2155}
2156
2157async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
2158 let db = session.db().await;
2159 let contacts = db.get_contacts(user_id).await?;
2160 let busy = db.is_user_busy(user_id).await?;
2161
2162 let pool = session.connection_pool().await;
2163 let updated_contact = contact_for_user(user_id, false, busy, &pool);
2164 for contact in contacts {
2165 if let db::Contact::Accepted {
2166 user_id: contact_user_id,
2167 ..
2168 } = contact
2169 {
2170 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
2171 session
2172 .peer
2173 .send(
2174 contact_conn_id,
2175 proto::UpdateContacts {
2176 contacts: vec![updated_contact.clone()],
2177 remove_contacts: Default::default(),
2178 incoming_requests: Default::default(),
2179 remove_incoming_requests: Default::default(),
2180 outgoing_requests: Default::default(),
2181 remove_outgoing_requests: Default::default(),
2182 },
2183 )
2184 .trace_err();
2185 }
2186 }
2187 }
2188 Ok(())
2189}
2190
2191async fn leave_room_for_session(session: &Session) -> Result<()> {
2192 let mut contacts_to_update = HashSet::default();
2193
2194 let room_id;
2195 let canceled_calls_to_user_ids;
2196 let live_kit_room;
2197 let delete_live_kit_room;
2198 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
2199 contacts_to_update.insert(session.user_id);
2200
2201 for project in left_room.left_projects.values() {
2202 project_left(project, session);
2203 }
2204
2205 room_updated(&left_room.room, &session.peer);
2206 room_id = RoomId::from_proto(left_room.room.id);
2207 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
2208 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
2209 delete_live_kit_room = left_room.room.participants.is_empty();
2210 } else {
2211 return Ok(());
2212 }
2213
2214 {
2215 let pool = session.connection_pool().await;
2216 for canceled_user_id in canceled_calls_to_user_ids {
2217 for connection_id in pool.user_connection_ids(canceled_user_id) {
2218 session
2219 .peer
2220 .send(
2221 connection_id,
2222 proto::CallCanceled {
2223 room_id: room_id.to_proto(),
2224 },
2225 )
2226 .trace_err();
2227 }
2228 contacts_to_update.insert(canceled_user_id);
2229 }
2230 }
2231
2232 for contact_user_id in contacts_to_update {
2233 update_user_contacts(contact_user_id, &session).await?;
2234 }
2235
2236 if let Some(live_kit) = session.live_kit_client.as_ref() {
2237 live_kit
2238 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
2239 .await
2240 .trace_err();
2241
2242 if delete_live_kit_room {
2243 live_kit.delete_room(live_kit_room).await.trace_err();
2244 }
2245 }
2246
2247 Ok(())
2248}
2249
2250fn project_left(project: &db::LeftProject, session: &Session) {
2251 for connection_id in &project.connection_ids {
2252 if project.host_user_id == session.user_id {
2253 session
2254 .peer
2255 .send(
2256 *connection_id,
2257 proto::UnshareProject {
2258 project_id: project.id.to_proto(),
2259 },
2260 )
2261 .trace_err();
2262 } else {
2263 session
2264 .peer
2265 .send(
2266 *connection_id,
2267 proto::RemoveProjectCollaborator {
2268 project_id: project.id.to_proto(),
2269 peer_id: Some(session.connection_id.into()),
2270 },
2271 )
2272 .trace_err();
2273 }
2274 }
2275}
2276
/// Extension trait for turning a fallible value into an `Option`.
pub trait ResultExt {
    /// The success type produced when there is no error.
    type Ok;

    /// Consumes `self` and returns the success value, if any, as an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
2282
2283impl<T, E> ResultExt for Result<T, E>
2284where
2285 E: std::fmt::Debug,
2286{
2287 type Ok = T;
2288
2289 fn trace_err(self) -> Option<T> {
2290 match self {
2291 Ok(value) => Some(value),
2292 Err(error) => {
2293 tracing::error!("{:?}", error);
2294 None
2295 }
2296 }
2297 }
2298}