1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection closes before making the
/// session leave its room and refreshing the user's contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before looking for stale rooms to refresh.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
lazy_static! {
    // Prometheus gauge refreshed by `handle_metrics` with the number of
    // non-admin connections currently in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge refreshed by `handle_metrics` from
    // `Database::project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
/// One-shot handle a request handler uses to reply to a request.
///
/// `responded` is shared with the dispatcher built in
/// `Server::add_request_handler`, which inspects it after the handler returns.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each handler invocation. Every shared resource sits
/// behind an `Arc`, so clones refer to the same database mutex, peer, and
/// connection pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Created once per connection in `Server::handle_connection`, so handlers
    // for one connection take turns using the database (see `Session::db`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // Copied from `AppState::live_kit_client`; `None` when that is `None`.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
/// Newtype around the shared `Database`. `Session::db` hands out a lock guard
/// over this type, which in turn dereferences to `Database`.
struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The RPC server. It owns the `Peer`, the pool of live connections, and the
/// table of registered message handlers.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `Server::teardown`; connection loops and `sign_out`
    // subscribe to it.
    teardown: watch::Sender<()>,
}
148
/// Lock guard over the shared `ConnectionPool`.
///
/// The `_not_send` marker (`PhantomData<Rc<()>>`) makes the guard `!Send`, so
/// it cannot be kept alive across an `.await` inside a future that must be
/// `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
153
/// Serializable view of the server's state, produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
160
161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
162where
163 S: Serializer,
164 T: Deref<Target = U>,
165 U: Serialize,
166{
167 Serialize::serialize(value.deref(), serializer)
168}
169
170impl Server {
171 pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
172 let mut server = Self {
173 peer: Peer::new(),
174 app_state,
175 executor,
176 connection_pool: Default::default(),
177 handlers: Default::default(),
178 teardown: watch::channel(()).0,
179 };
180
181 server
182 .add_request_handler(ping)
183 .add_request_handler(create_room)
184 .add_request_handler(join_room)
185 .add_message_handler(leave_room)
186 .add_request_handler(call)
187 .add_request_handler(cancel_call)
188 .add_message_handler(decline_call)
189 .add_request_handler(update_participant_location)
190 .add_request_handler(share_project)
191 .add_message_handler(unshare_project)
192 .add_request_handler(join_project)
193 .add_message_handler(leave_project)
194 .add_request_handler(update_project)
195 .add_request_handler(update_worktree)
196 .add_message_handler(start_language_server)
197 .add_message_handler(update_language_server)
198 .add_message_handler(update_diagnostic_summary)
199 .add_request_handler(forward_project_request::<proto::GetHover>)
200 .add_request_handler(forward_project_request::<proto::GetDefinition>)
201 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
202 .add_request_handler(forward_project_request::<proto::GetReferences>)
203 .add_request_handler(forward_project_request::<proto::SearchProject>)
204 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
205 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
206 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
209 .add_request_handler(forward_project_request::<proto::GetCompletions>)
210 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
211 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
212 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
213 .add_request_handler(forward_project_request::<proto::PrepareRename>)
214 .add_request_handler(forward_project_request::<proto::PerformRename>)
215 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
216 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
217 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
218 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
221 .add_message_handler(create_buffer_for_peer)
222 .add_request_handler(update_buffer)
223 .add_message_handler(update_buffer_file)
224 .add_message_handler(buffer_reloaded)
225 .add_message_handler(buffer_saved)
226 .add_request_handler(save_buffer)
227 .add_request_handler(get_users)
228 .add_request_handler(fuzzy_search_users)
229 .add_request_handler(request_contact)
230 .add_request_handler(remove_contact)
231 .add_request_handler(respond_to_contact_request)
232 .add_request_handler(follow)
233 .add_message_handler(unfollow)
234 .add_message_handler(update_followers)
235 .add_message_handler(update_diff_base)
236 .add_request_handler(get_private_user_info);
237
238 Arc::new(server)
239 }
240
241 pub async fn start(&self) -> Result<()> {
242 let db = self.app_state.db.clone();
243 let peer = self.peer.clone();
244 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
245 let pool = self.connection_pool.clone();
246 let live_kit_client = self.app_state.live_kit_client.clone();
247
248 let span = info_span!("start server");
249 let span_enter = span.enter();
250
251 tracing::info!("begin deleting stale projects");
252 self.app_state.db.delete_stale_projects().await?;
253 tracing::info!("finish deleting stale projects");
254
255 drop(span_enter);
256 self.executor.spawn_detached(
257 async move {
258 tracing::info!("waiting for cleanup timeout");
259 timeout.await;
260 tracing::info!("cleanup timeout expired, retrieving stale rooms");
261 if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
262 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
263 for room_id in room_ids {
264 let mut contacts_to_update = HashSet::default();
265 let mut canceled_calls_to_user_ids = Vec::new();
266 let mut live_kit_room = String::new();
267 let mut delete_live_kit_room = false;
268
269 if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
270 tracing::info!(
271 room_id = room_id.0,
272 new_participant_count = refreshed_room.room.participants.len(),
273 "refreshed room"
274 );
275 room_updated(&refreshed_room.room, &peer);
276 contacts_to_update
277 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
278 contacts_to_update
279 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
280 canceled_calls_to_user_ids =
281 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
282 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
283 delete_live_kit_room = refreshed_room.room.participants.is_empty();
284 }
285
286 {
287 let pool = pool.lock();
288 for canceled_user_id in canceled_calls_to_user_ids {
289 for connection_id in pool.user_connection_ids(canceled_user_id) {
290 peer.send(
291 connection_id,
292 proto::CallCanceled {
293 room_id: room_id.to_proto(),
294 },
295 )
296 .trace_err();
297 }
298 }
299 }
300
301 for user_id in contacts_to_update {
302 let busy = db.is_user_busy(user_id).await.trace_err();
303 let contacts = db.get_contacts(user_id).await.trace_err();
304 if let Some((busy, contacts)) = busy.zip(contacts) {
305 let pool = pool.lock();
306 let updated_contact = contact_for_user(user_id, false, busy, &pool);
307 for contact in contacts {
308 if let db::Contact::Accepted {
309 user_id: contact_user_id,
310 ..
311 } = contact
312 {
313 for contact_conn_id in
314 pool.user_connection_ids(contact_user_id)
315 {
316 peer.send(
317 contact_conn_id,
318 proto::UpdateContacts {
319 contacts: vec![updated_contact.clone()],
320 remove_contacts: Default::default(),
321 incoming_requests: Default::default(),
322 remove_incoming_requests: Default::default(),
323 outgoing_requests: Default::default(),
324 remove_outgoing_requests: Default::default(),
325 },
326 )
327 .trace_err();
328 }
329 }
330 }
331 }
332 }
333
334 if let Some(live_kit) = live_kit_client.as_ref() {
335 if delete_live_kit_room {
336 live_kit.delete_room(live_kit_room).await.trace_err();
337 }
338 }
339 }
340 }
341 }
342 .instrument(span),
343 );
344 Ok(())
345 }
346
347 pub fn teardown(&self) {
348 self.peer.reset();
349 self.connection_pool.lock().reset();
350 let _ = self.teardown.send(());
351 }
352
353 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
354 where
355 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
356 Fut: 'static + Send + Future<Output = Result<()>>,
357 M: EnvelopedMessage,
358 {
359 let prev_handler = self.handlers.insert(
360 TypeId::of::<M>(),
361 Box::new(move |envelope, session| {
362 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
363 let span = info_span!(
364 "handle message",
365 payload_type = envelope.payload_type_name()
366 );
367 span.in_scope(|| {
368 tracing::info!(
369 payload_type = envelope.payload_type_name(),
370 "message received"
371 );
372 });
373 let future = (handler)(*envelope, session);
374 async move {
375 if let Err(error) = future.await {
376 tracing::error!(%error, "error handling message");
377 }
378 }
379 .instrument(span)
380 .boxed()
381 }),
382 );
383 if prev_handler.is_some() {
384 panic!("registered a handler for the same message twice");
385 }
386 self
387 }
388
389 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
390 where
391 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
392 Fut: 'static + Send + Future<Output = Result<()>>,
393 M: EnvelopedMessage,
394 {
395 self.add_handler(move |envelope, session| handler(envelope.payload, session));
396 self
397 }
398
399 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
400 where
401 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
402 Fut: Send + Future<Output = Result<()>>,
403 M: RequestMessage,
404 {
405 let handler = Arc::new(handler);
406 self.add_handler(move |envelope, session| {
407 let receipt = envelope.receipt();
408 let handler = handler.clone();
409 async move {
410 let peer = session.peer.clone();
411 let responded = Arc::new(AtomicBool::default());
412 let response = Response {
413 peer: peer.clone(),
414 responded: responded.clone(),
415 receipt,
416 };
417 match (handler)(envelope.payload, response, session).await {
418 Ok(()) => {
419 if responded.load(std::sync::atomic::Ordering::SeqCst) {
420 Ok(())
421 } else {
422 Err(anyhow!("handler did not send a response"))?
423 }
424 }
425 Err(error) => {
426 peer.respond_with_error(
427 receipt,
428 proto::Error {
429 message: error.to_string(),
430 },
431 )?;
432 Err(error)
433 }
434 }
435 }
436 })
437 }
438
439 pub fn handle_connection(
440 self: &Arc<Self>,
441 connection: Connection,
442 address: String,
443 user: User,
444 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
445 executor: Executor,
446 ) -> impl Future<Output = Result<()>> {
447 let this = self.clone();
448 let user_id = user.id;
449 let login = user.github_login;
450 let span = info_span!("handle connection", %user_id, %login, %address);
451 let mut teardown = self.teardown.subscribe();
452 async move {
453 let (connection_id, handle_io, mut incoming_rx) = this
454 .peer
455 .add_connection(connection, {
456 let executor = executor.clone();
457 move |duration| executor.sleep(duration)
458 });
459
460 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
461 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
462 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
463
464 if let Some(send_connection_id) = send_connection_id.take() {
465 let _ = send_connection_id.send(connection_id);
466 }
467
468 if !user.connected_once {
469 this.peer.send(connection_id, proto::ShowContacts {})?;
470 this.app_state.db.set_user_connected_once(user_id, true).await?;
471 }
472
473 let (contacts, invite_code) = future::try_join(
474 this.app_state.db.get_contacts(user_id),
475 this.app_state.db.get_invite_code_for_user(user_id)
476 ).await?;
477
478 {
479 let mut pool = this.connection_pool.lock();
480 pool.add_connection(connection_id, user_id, user.admin);
481 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
482
483 if let Some((code, count)) = invite_code {
484 this.peer.send(connection_id, proto::UpdateInviteInfo {
485 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
486 count: count as u32,
487 })?;
488 }
489 }
490
491 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
492 this.peer.send(connection_id, incoming_call)?;
493 }
494
495 let session = Session {
496 user_id,
497 connection_id,
498 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
499 peer: this.peer.clone(),
500 connection_pool: this.connection_pool.clone(),
501 live_kit_client: this.app_state.live_kit_client.clone()
502 };
503 update_user_contacts(user_id, &session).await?;
504
505 let handle_io = handle_io.fuse();
506 futures::pin_mut!(handle_io);
507
508 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
509 // This prevents deadlocks when e.g., client A performs a request to client B and
510 // client B performs a request to client A. If both clients stop processing further
511 // messages until their respective request completes, they won't have a chance to
512 // respond to the other client's request and cause a deadlock.
513 //
514 // This arrangement ensures we will attempt to process earlier messages first, but fall
515 // back to processing messages arrived later in the spirit of making progress.
516 let mut foreground_message_handlers = FuturesUnordered::new();
517 loop {
518 let next_message = incoming_rx.next().fuse();
519 futures::pin_mut!(next_message);
520 futures::select_biased! {
521 _ = teardown.changed().fuse() => return Ok(()),
522 result = handle_io => {
523 if let Err(error) = result {
524 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
525 }
526 break;
527 }
528 _ = foreground_message_handlers.next() => {}
529 message = next_message => {
530 if let Some(message) = message {
531 let type_name = message.payload_type_name();
532 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
533 let span_enter = span.enter();
534 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
535 let is_background = message.is_background();
536 let handle_message = (handler)(message, session.clone());
537 drop(span_enter);
538
539 let handle_message = handle_message.instrument(span);
540 if is_background {
541 executor.spawn_detached(handle_message);
542 } else {
543 foreground_message_handlers.push(handle_message);
544 }
545 } else {
546 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
547 }
548 } else {
549 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
550 break;
551 }
552 }
553 }
554 }
555
556 drop(foreground_message_handlers);
557 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
558 if let Err(error) = sign_out(session, teardown, executor).await {
559 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
560 }
561
562 Ok(())
563 }.instrument(span)
564 }
565
566 pub async fn invite_code_redeemed(
567 self: &Arc<Self>,
568 inviter_id: UserId,
569 invitee_id: UserId,
570 ) -> Result<()> {
571 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
572 if let Some(code) = &user.invite_code {
573 let pool = self.connection_pool.lock();
574 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
575 for connection_id in pool.user_connection_ids(inviter_id) {
576 self.peer.send(
577 connection_id,
578 proto::UpdateContacts {
579 contacts: vec![invitee_contact.clone()],
580 ..Default::default()
581 },
582 )?;
583 self.peer.send(
584 connection_id,
585 proto::UpdateInviteInfo {
586 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
587 count: user.invite_count as u32,
588 },
589 )?;
590 }
591 }
592 }
593 Ok(())
594 }
595
596 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
597 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
598 if let Some(invite_code) = &user.invite_code {
599 let pool = self.connection_pool.lock();
600 for connection_id in pool.user_connection_ids(user_id) {
601 self.peer.send(
602 connection_id,
603 proto::UpdateInviteInfo {
604 url: format!(
605 "{}{}",
606 self.app_state.config.invite_link_prefix, invite_code
607 ),
608 count: user.invite_count as u32,
609 },
610 )?;
611 }
612 }
613 }
614 Ok(())
615 }
616
617 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
618 ServerSnapshot {
619 connection_pool: ConnectionPoolGuard {
620 guard: self.connection_pool.lock(),
621 _not_send: PhantomData,
622 },
623 peer: &self.peer,
624 }
625 }
626}
627
628impl<'a> Deref for ConnectionPoolGuard<'a> {
629 type Target = ConnectionPool;
630
631 fn deref(&self) -> &Self::Target {
632 &*self.guard
633 }
634}
635
636impl<'a> DerefMut for ConnectionPoolGuard<'a> {
637 fn deref_mut(&mut self) -> &mut Self::Target {
638 &mut *self.guard
639 }
640}
641
642impl<'a> Drop for ConnectionPoolGuard<'a> {
643 fn drop(&mut self) {
644 #[cfg(test)]
645 self.check_invariants();
646 }
647}
648
649fn broadcast<F>(
650 sender_id: ConnectionId,
651 receiver_ids: impl IntoIterator<Item = ConnectionId>,
652 mut f: F,
653) where
654 F: FnMut(ConnectionId) -> anyhow::Result<()>,
655{
656 for receiver_id in receiver_ids {
657 if receiver_id != sender_id {
658 f(receiver_id).trace_err();
659 }
660 }
661}
662
lazy_static! {
    // Name of the request header carrying the client's RPC protocol version
    // (see `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
666
/// Typed `x-zed-protocol-version` header. `handle_websocket_request` compares
/// it with `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
668
669impl Header for ProtocolVersion {
670 fn name() -> &'static HeaderName {
671 &ZED_PROTOCOL_VERSION
672 }
673
674 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
675 where
676 Self: Sized,
677 I: Iterator<Item = &'i axum::http::HeaderValue>,
678 {
679 let version = values
680 .next()
681 .ok_or_else(axum::headers::Error::invalid)?
682 .to_str()
683 .map_err(|_| axum::headers::Error::invalid())?
684 .parse()
685 .map_err(|_| axum::headers::Error::invalid())?;
686 Ok(Self(version))
687 }
688
689 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
690 values.extend([self.0.to_string().parse().unwrap()]);
691 }
692}
693
694pub fn routes(server: Arc<Server>) -> Router<Body> {
695 Router::new()
696 .route("/rpc", get(handle_websocket_request))
697 .layer(
698 ServiceBuilder::new()
699 .layer(Extension(server.app_state.clone()))
700 .layer(middleware::from_fn(auth::validate_header)),
701 )
702 .route("/metrics", get(handle_metrics))
703 .layer(Extension(server))
704}
705
706pub async fn handle_websocket_request(
707 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
708 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
709 Extension(server): Extension<Arc<Server>>,
710 Extension(user): Extension<User>,
711 ws: WebSocketUpgrade,
712) -> axum::response::Response {
713 if protocol_version != rpc::PROTOCOL_VERSION {
714 return (
715 StatusCode::UPGRADE_REQUIRED,
716 "client must be upgraded".to_string(),
717 )
718 .into_response();
719 }
720 let socket_address = socket_address.to_string();
721 ws.on_upgrade(move |socket| {
722 use util::ResultExt;
723 let socket = socket
724 .map_ok(to_tungstenite_message)
725 .err_into()
726 .with(|message| async move { Ok(to_axum_message(message)) });
727 let connection = Connection::new(Box::pin(socket));
728 async move {
729 server
730 .handle_connection(connection, socket_address, user, None, Executor::Production)
731 .await
732 .log_err();
733 }
734 })
735}
736
737pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
738 let connections = server
739 .connection_pool
740 .lock()
741 .connections()
742 .filter(|connection| !connection.admin)
743 .count();
744
745 METRIC_CONNECTIONS.set(connections as _);
746
747 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
748 METRIC_SHARED_PROJECTS.set(shared_projects as _);
749
750 let encoder = prometheus::TextEncoder::new();
751 let metric_families = prometheus::gather();
752 let encoded_metrics = encoder
753 .encode_to_string(&metric_families)
754 .map_err(|err| anyhow!("{}", err))?;
755 Ok(encoded_metrics)
756}
757
/// Cleans up after a connection has closed.
///
/// Steps:
/// 1. Disconnect the peer and remove the connection from the pool (an error
///    here aborts the rest).
/// 2. Ask the database which projects the connection left, and call
///    `project_left` for each.
/// 3. Wait `RECONNECT_TIMEOUT`; if the server is torn down first, stop here.
/// 4. Leave the session's room (errors only traced), decline a pending call
///    if the user has no other connection online, and update the user's
///    contacts.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls the reconnect timer first; a teardown skips the
    // remaining cleanup entirely.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline a pending call when none of this user's other
            // connections are still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
803
804async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
805 response.send(proto::Ack {})?;
806 Ok(())
807}
808
809async fn create_room(
810 _request: proto::CreateRoom,
811 response: Response<proto::CreateRoom>,
812 session: Session,
813) -> Result<()> {
814 let live_kit_room = nanoid::nanoid!(30);
815 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
816 if let Some(_) = live_kit
817 .create_room(live_kit_room.clone())
818 .await
819 .trace_err()
820 {
821 if let Some(token) = live_kit
822 .room_token(&live_kit_room, &session.connection_id.to_string())
823 .trace_err()
824 {
825 Some(proto::LiveKitConnectionInfo {
826 server_url: live_kit.url().into(),
827 token,
828 })
829 } else {
830 None
831 }
832 } else {
833 None
834 }
835 } else {
836 None
837 };
838
839 {
840 let room = session
841 .db()
842 .await
843 .create_room(session.user_id, session.connection_id, &live_kit_room)
844 .await?;
845
846 response.send(proto::CreateRoomResponse {
847 room: Some(room.clone()),
848 live_kit_connection_info,
849 })?;
850 }
851
852 update_user_contacts(session.user_id, &session).await?;
853 Ok(())
854}
855
856async fn join_room(
857 request: proto::JoinRoom,
858 response: Response<proto::JoinRoom>,
859 session: Session,
860) -> Result<()> {
861 let room_id = RoomId::from_proto(request.id);
862 let room = {
863 let room = session
864 .db()
865 .await
866 .join_room(room_id, session.user_id, session.connection_id)
867 .await?;
868 room_updated(&room, &session.peer);
869 room.clone()
870 };
871
872 for connection_id in session
873 .connection_pool()
874 .await
875 .user_connection_ids(session.user_id)
876 {
877 session
878 .peer
879 .send(
880 connection_id,
881 proto::CallCanceled {
882 room_id: room_id.to_proto(),
883 },
884 )
885 .trace_err();
886 }
887
888 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
889 if let Some(token) = live_kit
890 .room_token(&room.live_kit_room, &session.connection_id.to_string())
891 .trace_err()
892 {
893 Some(proto::LiveKitConnectionInfo {
894 server_url: live_kit.url().into(),
895 token,
896 })
897 } else {
898 None
899 }
900 } else {
901 None
902 };
903
904 response.send(proto::JoinRoomResponse {
905 room: Some(room),
906 live_kit_connection_info,
907 })?;
908
909 update_user_contacts(session.user_id, &session).await?;
910 Ok(())
911}
912
913async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
914 leave_room_for_session(&session).await
915}
916
917async fn call(
918 request: proto::Call,
919 response: Response<proto::Call>,
920 session: Session,
921) -> Result<()> {
922 let room_id = RoomId::from_proto(request.room_id);
923 let calling_user_id = session.user_id;
924 let calling_connection_id = session.connection_id;
925 let called_user_id = UserId::from_proto(request.called_user_id);
926 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
927 if !session
928 .db()
929 .await
930 .has_contact(calling_user_id, called_user_id)
931 .await?
932 {
933 return Err(anyhow!("cannot call a user who isn't a contact"))?;
934 }
935
936 let incoming_call = {
937 let (room, incoming_call) = &mut *session
938 .db()
939 .await
940 .call(
941 room_id,
942 calling_user_id,
943 calling_connection_id,
944 called_user_id,
945 initial_project_id,
946 )
947 .await?;
948 room_updated(&room, &session.peer);
949 mem::take(incoming_call)
950 };
951 update_user_contacts(called_user_id, &session).await?;
952
953 let mut calls = session
954 .connection_pool()
955 .await
956 .user_connection_ids(called_user_id)
957 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
958 .collect::<FuturesUnordered<_>>();
959
960 while let Some(call_response) = calls.next().await {
961 match call_response.as_ref() {
962 Ok(_) => {
963 response.send(proto::Ack {})?;
964 return Ok(());
965 }
966 Err(_) => {
967 call_response.trace_err();
968 }
969 }
970 }
971
972 {
973 let room = session
974 .db()
975 .await
976 .call_failed(room_id, called_user_id)
977 .await?;
978 room_updated(&room, &session.peer);
979 }
980 update_user_contacts(called_user_id, &session).await?;
981
982 Err(anyhow!("failed to ring user"))?
983}
984
985async fn cancel_call(
986 request: proto::CancelCall,
987 response: Response<proto::CancelCall>,
988 session: Session,
989) -> Result<()> {
990 let called_user_id = UserId::from_proto(request.called_user_id);
991 let room_id = RoomId::from_proto(request.room_id);
992 {
993 let room = session
994 .db()
995 .await
996 .cancel_call(Some(room_id), session.connection_id, called_user_id)
997 .await?;
998 room_updated(&room, &session.peer);
999 }
1000
1001 for connection_id in session
1002 .connection_pool()
1003 .await
1004 .user_connection_ids(called_user_id)
1005 {
1006 session
1007 .peer
1008 .send(
1009 connection_id,
1010 proto::CallCanceled {
1011 room_id: room_id.to_proto(),
1012 },
1013 )
1014 .trace_err();
1015 }
1016 response.send(proto::Ack {})?;
1017
1018 update_user_contacts(called_user_id, &session).await?;
1019 Ok(())
1020}
1021
1022async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1023 let room_id = RoomId::from_proto(message.room_id);
1024 {
1025 let room = session
1026 .db()
1027 .await
1028 .decline_call(Some(room_id), session.user_id)
1029 .await?;
1030 room_updated(&room, &session.peer);
1031 }
1032
1033 for connection_id in session
1034 .connection_pool()
1035 .await
1036 .user_connection_ids(session.user_id)
1037 {
1038 session
1039 .peer
1040 .send(
1041 connection_id,
1042 proto::CallCanceled {
1043 room_id: room_id.to_proto(),
1044 },
1045 )
1046 .trace_err();
1047 }
1048 update_user_contacts(session.user_id, &session).await?;
1049 Ok(())
1050}
1051
1052async fn update_participant_location(
1053 request: proto::UpdateParticipantLocation,
1054 response: Response<proto::UpdateParticipantLocation>,
1055 session: Session,
1056) -> Result<()> {
1057 let room_id = RoomId::from_proto(request.room_id);
1058 let location = request
1059 .location
1060 .ok_or_else(|| anyhow!("invalid location"))?;
1061 let room = session
1062 .db()
1063 .await
1064 .update_room_participant_location(room_id, session.connection_id, location)
1065 .await?;
1066 room_updated(&room, &session.peer);
1067 response.send(proto::Ack {})?;
1068 Ok(())
1069}
1070
1071async fn share_project(
1072 request: proto::ShareProject,
1073 response: Response<proto::ShareProject>,
1074 session: Session,
1075) -> Result<()> {
1076 let (project_id, room) = &*session
1077 .db()
1078 .await
1079 .share_project(
1080 RoomId::from_proto(request.room_id),
1081 session.connection_id,
1082 &request.worktrees,
1083 )
1084 .await?;
1085 response.send(proto::ShareProjectResponse {
1086 project_id: project_id.to_proto(),
1087 })?;
1088 room_updated(&room, &session.peer);
1089
1090 Ok(())
1091}
1092
1093async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1094 let project_id = ProjectId::from_proto(message.project_id);
1095
1096 let (room, guest_connection_ids) = &*session
1097 .db()
1098 .await
1099 .unshare_project(project_id, session.connection_id)
1100 .await?;
1101
1102 broadcast(
1103 session.connection_id,
1104 guest_connection_ids.iter().copied(),
1105 |conn_id| session.peer.send(conn_id, message.clone()),
1106 );
1107 room_updated(&room, &session.peer);
1108
1109 Ok(())
1110}
1111
1112async fn join_project(
1113 request: proto::JoinProject,
1114 response: Response<proto::JoinProject>,
1115 session: Session,
1116) -> Result<()> {
1117 let project_id = ProjectId::from_proto(request.project_id);
1118 let guest_user_id = session.user_id;
1119
1120 tracing::info!(%project_id, "join project");
1121
1122 let (project, replica_id) = &mut *session
1123 .db()
1124 .await
1125 .join_project(project_id, session.connection_id)
1126 .await?;
1127
1128 let collaborators = project
1129 .collaborators
1130 .iter()
1131 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1132 .map(|collaborator| proto::Collaborator {
1133 peer_id: collaborator.connection_id as u32,
1134 replica_id: collaborator.replica_id.0 as u32,
1135 user_id: collaborator.user_id.to_proto(),
1136 })
1137 .collect::<Vec<_>>();
1138 let worktrees = project
1139 .worktrees
1140 .iter()
1141 .map(|(id, worktree)| proto::WorktreeMetadata {
1142 id: *id,
1143 root_name: worktree.root_name.clone(),
1144 visible: worktree.visible,
1145 abs_path: worktree.abs_path.clone(),
1146 })
1147 .collect::<Vec<_>>();
1148
1149 for collaborator in &collaborators {
1150 session
1151 .peer
1152 .send(
1153 ConnectionId(collaborator.peer_id),
1154 proto::AddProjectCollaborator {
1155 project_id: project_id.to_proto(),
1156 collaborator: Some(proto::Collaborator {
1157 peer_id: session.connection_id.0,
1158 replica_id: replica_id.0 as u32,
1159 user_id: guest_user_id.to_proto(),
1160 }),
1161 },
1162 )
1163 .trace_err();
1164 }
1165
1166 // First, we send the metadata associated with each worktree.
1167 response.send(proto::JoinProjectResponse {
1168 worktrees: worktrees.clone(),
1169 replica_id: replica_id.0 as u32,
1170 collaborators: collaborators.clone(),
1171 language_servers: project.language_servers.clone(),
1172 })?;
1173
1174 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1175 #[cfg(any(test, feature = "test-support"))]
1176 const MAX_CHUNK_SIZE: usize = 2;
1177 #[cfg(not(any(test, feature = "test-support")))]
1178 const MAX_CHUNK_SIZE: usize = 256;
1179
1180 // Stream this worktree's entries.
1181 let message = proto::UpdateWorktree {
1182 project_id: project_id.to_proto(),
1183 worktree_id,
1184 abs_path: worktree.abs_path.clone(),
1185 root_name: worktree.root_name,
1186 updated_entries: worktree.entries,
1187 removed_entries: Default::default(),
1188 scan_id: worktree.scan_id,
1189 is_last_update: worktree.is_complete,
1190 };
1191 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1192 session.peer.send(session.connection_id, update.clone())?;
1193 }
1194
1195 // Stream this worktree's diagnostics.
1196 for summary in worktree.diagnostic_summaries {
1197 session.peer.send(
1198 session.connection_id,
1199 proto::UpdateDiagnosticSummary {
1200 project_id: project_id.to_proto(),
1201 worktree_id: worktree.id,
1202 summary: Some(summary),
1203 },
1204 )?;
1205 }
1206 }
1207
1208 for language_server in &project.language_servers {
1209 session.peer.send(
1210 session.connection_id,
1211 proto::UpdateLanguageServer {
1212 project_id: project_id.to_proto(),
1213 language_server_id: language_server.id,
1214 variant: Some(
1215 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1216 proto::LspDiskBasedDiagnosticsUpdated {},
1217 ),
1218 ),
1219 },
1220 )?;
1221 }
1222
1223 Ok(())
1224}
1225
1226async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1227 let sender_id = session.connection_id;
1228 let project_id = ProjectId::from_proto(request.project_id);
1229
1230 let project = session
1231 .db()
1232 .await
1233 .leave_project(project_id, sender_id)
1234 .await?;
1235 tracing::info!(
1236 %project_id,
1237 host_user_id = %project.host_user_id,
1238 host_connection_id = %project.host_connection_id,
1239 "leave project"
1240 );
1241 project_left(&project, &session);
1242
1243 Ok(())
1244}
1245
1246async fn update_project(
1247 request: proto::UpdateProject,
1248 response: Response<proto::UpdateProject>,
1249 session: Session,
1250) -> Result<()> {
1251 let project_id = ProjectId::from_proto(request.project_id);
1252 let (room, guest_connection_ids) = &*session
1253 .db()
1254 .await
1255 .update_project(project_id, session.connection_id, &request.worktrees)
1256 .await?;
1257 broadcast(
1258 session.connection_id,
1259 guest_connection_ids.iter().copied(),
1260 |connection_id| {
1261 session
1262 .peer
1263 .forward_send(session.connection_id, connection_id, request.clone())
1264 },
1265 );
1266 room_updated(&room, &session.peer);
1267 response.send(proto::Ack {})?;
1268
1269 Ok(())
1270}
1271
1272async fn update_worktree(
1273 request: proto::UpdateWorktree,
1274 response: Response<proto::UpdateWorktree>,
1275 session: Session,
1276) -> Result<()> {
1277 let guest_connection_ids = session
1278 .db()
1279 .await
1280 .update_worktree(&request, session.connection_id)
1281 .await?;
1282
1283 broadcast(
1284 session.connection_id,
1285 guest_connection_ids.iter().copied(),
1286 |connection_id| {
1287 session
1288 .peer
1289 .forward_send(session.connection_id, connection_id, request.clone())
1290 },
1291 );
1292 response.send(proto::Ack {})?;
1293 Ok(())
1294}
1295
1296async fn update_diagnostic_summary(
1297 message: proto::UpdateDiagnosticSummary,
1298 session: Session,
1299) -> Result<()> {
1300 let guest_connection_ids = session
1301 .db()
1302 .await
1303 .update_diagnostic_summary(&message, session.connection_id)
1304 .await?;
1305
1306 broadcast(
1307 session.connection_id,
1308 guest_connection_ids.iter().copied(),
1309 |connection_id| {
1310 session
1311 .peer
1312 .forward_send(session.connection_id, connection_id, message.clone())
1313 },
1314 );
1315
1316 Ok(())
1317}
1318
1319async fn start_language_server(
1320 request: proto::StartLanguageServer,
1321 session: Session,
1322) -> Result<()> {
1323 let guest_connection_ids = session
1324 .db()
1325 .await
1326 .start_language_server(&request, session.connection_id)
1327 .await?;
1328
1329 broadcast(
1330 session.connection_id,
1331 guest_connection_ids.iter().copied(),
1332 |connection_id| {
1333 session
1334 .peer
1335 .forward_send(session.connection_id, connection_id, request.clone())
1336 },
1337 );
1338 Ok(())
1339}
1340
1341async fn update_language_server(
1342 request: proto::UpdateLanguageServer,
1343 session: Session,
1344) -> Result<()> {
1345 let project_id = ProjectId::from_proto(request.project_id);
1346 let project_connection_ids = session
1347 .db()
1348 .await
1349 .project_connection_ids(project_id, session.connection_id)
1350 .await?;
1351 broadcast(
1352 session.connection_id,
1353 project_connection_ids.iter().copied(),
1354 |connection_id| {
1355 session
1356 .peer
1357 .forward_send(session.connection_id, connection_id, request.clone())
1358 },
1359 );
1360 Ok(())
1361}
1362
1363async fn forward_project_request<T>(
1364 request: T,
1365 response: Response<T>,
1366 session: Session,
1367) -> Result<()>
1368where
1369 T: EntityMessage + RequestMessage,
1370{
1371 let project_id = ProjectId::from_proto(request.remote_entity_id());
1372 let host_connection_id = {
1373 let collaborators = session
1374 .db()
1375 .await
1376 .project_collaborators(project_id, session.connection_id)
1377 .await?;
1378 ConnectionId(
1379 collaborators
1380 .iter()
1381 .find(|collaborator| collaborator.is_host)
1382 .ok_or_else(|| anyhow!("host not found"))?
1383 .connection_id as u32,
1384 )
1385 };
1386
1387 let payload = session
1388 .peer
1389 .forward_request(session.connection_id, host_connection_id, request)
1390 .await?;
1391
1392 response.send(payload)?;
1393 Ok(())
1394}
1395
1396async fn save_buffer(
1397 request: proto::SaveBuffer,
1398 response: Response<proto::SaveBuffer>,
1399 session: Session,
1400) -> Result<()> {
1401 let project_id = ProjectId::from_proto(request.project_id);
1402 let host_connection_id = {
1403 let collaborators = session
1404 .db()
1405 .await
1406 .project_collaborators(project_id, session.connection_id)
1407 .await?;
1408 let host = collaborators
1409 .iter()
1410 .find(|collaborator| collaborator.is_host)
1411 .ok_or_else(|| anyhow!("host not found"))?;
1412 ConnectionId(host.connection_id as u32)
1413 };
1414 let response_payload = session
1415 .peer
1416 .forward_request(session.connection_id, host_connection_id, request.clone())
1417 .await?;
1418
1419 let mut collaborators = session
1420 .db()
1421 .await
1422 .project_collaborators(project_id, session.connection_id)
1423 .await?;
1424 collaborators
1425 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1426 let project_connection_ids = collaborators
1427 .iter()
1428 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1429 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1430 session
1431 .peer
1432 .forward_send(host_connection_id, conn_id, response_payload.clone())
1433 });
1434 response.send(response_payload)?;
1435 Ok(())
1436}
1437
1438async fn create_buffer_for_peer(
1439 request: proto::CreateBufferForPeer,
1440 session: Session,
1441) -> Result<()> {
1442 session.peer.forward_send(
1443 session.connection_id,
1444 ConnectionId(request.peer_id),
1445 request,
1446 )?;
1447 Ok(())
1448}
1449
1450async fn update_buffer(
1451 request: proto::UpdateBuffer,
1452 response: Response<proto::UpdateBuffer>,
1453 session: Session,
1454) -> Result<()> {
1455 let project_id = ProjectId::from_proto(request.project_id);
1456 let project_connection_ids = session
1457 .db()
1458 .await
1459 .project_connection_ids(project_id, session.connection_id)
1460 .await?;
1461
1462 broadcast(
1463 session.connection_id,
1464 project_connection_ids.iter().copied(),
1465 |connection_id| {
1466 session
1467 .peer
1468 .forward_send(session.connection_id, connection_id, request.clone())
1469 },
1470 );
1471 response.send(proto::Ack {})?;
1472 Ok(())
1473}
1474
1475async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1476 let project_id = ProjectId::from_proto(request.project_id);
1477 let project_connection_ids = session
1478 .db()
1479 .await
1480 .project_connection_ids(project_id, session.connection_id)
1481 .await?;
1482
1483 broadcast(
1484 session.connection_id,
1485 project_connection_ids.iter().copied(),
1486 |connection_id| {
1487 session
1488 .peer
1489 .forward_send(session.connection_id, connection_id, request.clone())
1490 },
1491 );
1492 Ok(())
1493}
1494
1495async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1496 let project_id = ProjectId::from_proto(request.project_id);
1497 let project_connection_ids = session
1498 .db()
1499 .await
1500 .project_connection_ids(project_id, session.connection_id)
1501 .await?;
1502 broadcast(
1503 session.connection_id,
1504 project_connection_ids.iter().copied(),
1505 |connection_id| {
1506 session
1507 .peer
1508 .forward_send(session.connection_id, connection_id, request.clone())
1509 },
1510 );
1511 Ok(())
1512}
1513
1514async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1515 let project_id = ProjectId::from_proto(request.project_id);
1516 let project_connection_ids = session
1517 .db()
1518 .await
1519 .project_connection_ids(project_id, session.connection_id)
1520 .await?;
1521 broadcast(
1522 session.connection_id,
1523 project_connection_ids.iter().copied(),
1524 |connection_id| {
1525 session
1526 .peer
1527 .forward_send(session.connection_id, connection_id, request.clone())
1528 },
1529 );
1530 Ok(())
1531}
1532
1533async fn follow(
1534 request: proto::Follow,
1535 response: Response<proto::Follow>,
1536 session: Session,
1537) -> Result<()> {
1538 let project_id = ProjectId::from_proto(request.project_id);
1539 let leader_id = ConnectionId(request.leader_id);
1540 let follower_id = session.connection_id;
1541 {
1542 let project_connection_ids = session
1543 .db()
1544 .await
1545 .project_connection_ids(project_id, session.connection_id)
1546 .await?;
1547
1548 if !project_connection_ids.contains(&leader_id) {
1549 Err(anyhow!("no such peer"))?;
1550 }
1551 }
1552
1553 let mut response_payload = session
1554 .peer
1555 .forward_request(session.connection_id, leader_id, request)
1556 .await?;
1557 response_payload
1558 .views
1559 .retain(|view| view.leader_id != Some(follower_id.0));
1560 response.send(response_payload)?;
1561 Ok(())
1562}
1563
1564async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1565 let project_id = ProjectId::from_proto(request.project_id);
1566 let leader_id = ConnectionId(request.leader_id);
1567 let project_connection_ids = session
1568 .db()
1569 .await
1570 .project_connection_ids(project_id, session.connection_id)
1571 .await?;
1572 if !project_connection_ids.contains(&leader_id) {
1573 Err(anyhow!("no such peer"))?;
1574 }
1575 session
1576 .peer
1577 .forward_send(session.connection_id, leader_id, request)?;
1578 Ok(())
1579}
1580
1581async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1582 let project_id = ProjectId::from_proto(request.project_id);
1583 let project_connection_ids = session
1584 .db
1585 .lock()
1586 .await
1587 .project_connection_ids(project_id, session.connection_id)
1588 .await?;
1589
1590 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1591 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1592 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1593 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1594 });
1595 for follower_id in &request.follower_ids {
1596 let follower_id = ConnectionId(*follower_id);
1597 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1598 session
1599 .peer
1600 .forward_send(session.connection_id, follower_id, request.clone())?;
1601 }
1602 }
1603 Ok(())
1604}
1605
1606async fn get_users(
1607 request: proto::GetUsers,
1608 response: Response<proto::GetUsers>,
1609 session: Session,
1610) -> Result<()> {
1611 let user_ids = request
1612 .user_ids
1613 .into_iter()
1614 .map(UserId::from_proto)
1615 .collect();
1616 let users = session
1617 .db()
1618 .await
1619 .get_users_by_ids(user_ids)
1620 .await?
1621 .into_iter()
1622 .map(|user| proto::User {
1623 id: user.id.to_proto(),
1624 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1625 github_login: user.github_login,
1626 })
1627 .collect();
1628 response.send(proto::UsersResponse { users })?;
1629 Ok(())
1630}
1631
1632async fn fuzzy_search_users(
1633 request: proto::FuzzySearchUsers,
1634 response: Response<proto::FuzzySearchUsers>,
1635 session: Session,
1636) -> Result<()> {
1637 let query = request.query;
1638 let users = match query.len() {
1639 0 => vec![],
1640 1 | 2 => session
1641 .db()
1642 .await
1643 .get_user_by_github_account(&query, None)
1644 .await?
1645 .into_iter()
1646 .collect(),
1647 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1648 };
1649 let users = users
1650 .into_iter()
1651 .filter(|user| user.id != session.user_id)
1652 .map(|user| proto::User {
1653 id: user.id.to_proto(),
1654 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1655 github_login: user.github_login,
1656 })
1657 .collect();
1658 response.send(proto::UsersResponse { users })?;
1659 Ok(())
1660}
1661
1662async fn request_contact(
1663 request: proto::RequestContact,
1664 response: Response<proto::RequestContact>,
1665 session: Session,
1666) -> Result<()> {
1667 let requester_id = session.user_id;
1668 let responder_id = UserId::from_proto(request.responder_id);
1669 if requester_id == responder_id {
1670 return Err(anyhow!("cannot add yourself as a contact"))?;
1671 }
1672
1673 session
1674 .db()
1675 .await
1676 .send_contact_request(requester_id, responder_id)
1677 .await?;
1678
1679 // Update outgoing contact requests of requester
1680 let mut update = proto::UpdateContacts::default();
1681 update.outgoing_requests.push(responder_id.to_proto());
1682 for connection_id in session
1683 .connection_pool()
1684 .await
1685 .user_connection_ids(requester_id)
1686 {
1687 session.peer.send(connection_id, update.clone())?;
1688 }
1689
1690 // Update incoming contact requests of responder
1691 let mut update = proto::UpdateContacts::default();
1692 update
1693 .incoming_requests
1694 .push(proto::IncomingContactRequest {
1695 requester_id: requester_id.to_proto(),
1696 should_notify: true,
1697 });
1698 for connection_id in session
1699 .connection_pool()
1700 .await
1701 .user_connection_ids(responder_id)
1702 {
1703 session.peer.send(connection_id, update.clone())?;
1704 }
1705
1706 response.send(proto::Ack {})?;
1707 Ok(())
1708}
1709
1710async fn respond_to_contact_request(
1711 request: proto::RespondToContactRequest,
1712 response: Response<proto::RespondToContactRequest>,
1713 session: Session,
1714) -> Result<()> {
1715 let responder_id = session.user_id;
1716 let requester_id = UserId::from_proto(request.requester_id);
1717 let db = session.db().await;
1718 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1719 db.dismiss_contact_notification(responder_id, requester_id)
1720 .await?;
1721 } else {
1722 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1723
1724 db.respond_to_contact_request(responder_id, requester_id, accept)
1725 .await?;
1726 let requester_busy = db.is_user_busy(requester_id).await?;
1727 let responder_busy = db.is_user_busy(responder_id).await?;
1728
1729 let pool = session.connection_pool().await;
1730 // Update responder with new contact
1731 let mut update = proto::UpdateContacts::default();
1732 if accept {
1733 update
1734 .contacts
1735 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1736 }
1737 update
1738 .remove_incoming_requests
1739 .push(requester_id.to_proto());
1740 for connection_id in pool.user_connection_ids(responder_id) {
1741 session.peer.send(connection_id, update.clone())?;
1742 }
1743
1744 // Update requester with new contact
1745 let mut update = proto::UpdateContacts::default();
1746 if accept {
1747 update
1748 .contacts
1749 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1750 }
1751 update
1752 .remove_outgoing_requests
1753 .push(responder_id.to_proto());
1754 for connection_id in pool.user_connection_ids(requester_id) {
1755 session.peer.send(connection_id, update.clone())?;
1756 }
1757 }
1758
1759 response.send(proto::Ack {})?;
1760 Ok(())
1761}
1762
1763async fn remove_contact(
1764 request: proto::RemoveContact,
1765 response: Response<proto::RemoveContact>,
1766 session: Session,
1767) -> Result<()> {
1768 let requester_id = session.user_id;
1769 let responder_id = UserId::from_proto(request.user_id);
1770 let db = session.db().await;
1771 db.remove_contact(requester_id, responder_id).await?;
1772
1773 let pool = session.connection_pool().await;
1774 // Update outgoing contact requests of requester
1775 let mut update = proto::UpdateContacts::default();
1776 update
1777 .remove_outgoing_requests
1778 .push(responder_id.to_proto());
1779 for connection_id in pool.user_connection_ids(requester_id) {
1780 session.peer.send(connection_id, update.clone())?;
1781 }
1782
1783 // Update incoming contact requests of responder
1784 let mut update = proto::UpdateContacts::default();
1785 update
1786 .remove_incoming_requests
1787 .push(requester_id.to_proto());
1788 for connection_id in pool.user_connection_ids(responder_id) {
1789 session.peer.send(connection_id, update.clone())?;
1790 }
1791
1792 response.send(proto::Ack {})?;
1793 Ok(())
1794}
1795
1796async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1797 let project_id = ProjectId::from_proto(request.project_id);
1798 let project_connection_ids = session
1799 .db()
1800 .await
1801 .project_connection_ids(project_id, session.connection_id)
1802 .await?;
1803 broadcast(
1804 session.connection_id,
1805 project_connection_ids.iter().copied(),
1806 |connection_id| {
1807 session
1808 .peer
1809 .forward_send(session.connection_id, connection_id, request.clone())
1810 },
1811 );
1812 Ok(())
1813}
1814
1815async fn get_private_user_info(
1816 _request: proto::GetPrivateUserInfo,
1817 response: Response<proto::GetPrivateUserInfo>,
1818 session: Session,
1819) -> Result<()> {
1820 let metrics_id = session
1821 .db()
1822 .await
1823 .get_user_metrics_id(session.user_id)
1824 .await?;
1825 let user = session
1826 .db()
1827 .await
1828 .get_user_by_id(session.user_id)
1829 .await?
1830 .ok_or_else(|| anyhow!("user not found"))?;
1831 response.send(proto::GetPrivateUserInfoResponse {
1832 metrics_id,
1833 staff: user.admin,
1834 })?;
1835 Ok(())
1836}
1837
1838fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1839 match message {
1840 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1841 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1842 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1843 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1844 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1845 code: frame.code.into(),
1846 reason: frame.reason,
1847 })),
1848 }
1849}
1850
1851fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1852 match message {
1853 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1854 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1855 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1856 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1857 AxumMessage::Close(frame) => {
1858 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1859 code: frame.code.into(),
1860 reason: frame.reason,
1861 }))
1862 }
1863 }
1864}
1865
1866fn build_initial_contacts_update(
1867 contacts: Vec<db::Contact>,
1868 pool: &ConnectionPool,
1869) -> proto::UpdateContacts {
1870 let mut update = proto::UpdateContacts::default();
1871
1872 for contact in contacts {
1873 match contact {
1874 db::Contact::Accepted {
1875 user_id,
1876 should_notify,
1877 busy,
1878 } => {
1879 update
1880 .contacts
1881 .push(contact_for_user(user_id, should_notify, busy, &pool));
1882 }
1883 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1884 db::Contact::Incoming {
1885 user_id,
1886 should_notify,
1887 } => update
1888 .incoming_requests
1889 .push(proto::IncomingContactRequest {
1890 requester_id: user_id.to_proto(),
1891 should_notify,
1892 }),
1893 }
1894 }
1895
1896 update
1897}
1898
1899fn contact_for_user(
1900 user_id: UserId,
1901 should_notify: bool,
1902 busy: bool,
1903 pool: &ConnectionPool,
1904) -> proto::Contact {
1905 proto::Contact {
1906 user_id: user_id.to_proto(),
1907 online: pool.is_user_online(user_id),
1908 busy,
1909 should_notify,
1910 }
1911}
1912
1913fn room_updated(room: &proto::Room, peer: &Peer) {
1914 for participant in &room.participants {
1915 peer.send(
1916 ConnectionId(participant.peer_id),
1917 proto::RoomUpdated {
1918 room: Some(room.clone()),
1919 },
1920 )
1921 .trace_err();
1922 }
1923}
1924
1925async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1926 let db = session.db().await;
1927 let contacts = db.get_contacts(user_id).await?;
1928 let busy = db.is_user_busy(user_id).await?;
1929
1930 let pool = session.connection_pool().await;
1931 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1932 for contact in contacts {
1933 if let db::Contact::Accepted {
1934 user_id: contact_user_id,
1935 ..
1936 } = contact
1937 {
1938 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1939 session
1940 .peer
1941 .send(
1942 contact_conn_id,
1943 proto::UpdateContacts {
1944 contacts: vec![updated_contact.clone()],
1945 remove_contacts: Default::default(),
1946 incoming_requests: Default::default(),
1947 remove_incoming_requests: Default::default(),
1948 outgoing_requests: Default::default(),
1949 remove_outgoing_requests: Default::default(),
1950 },
1951 )
1952 .trace_err();
1953 }
1954 }
1955 }
1956 Ok(())
1957}
1958
/// Removes the session's connection from its room and runs the follow-up
/// notifications.
///
/// In order, it:
/// - leaves the room in the database;
/// - calls `project_left` for every project the connection left;
/// - sends the updated room to its participants;
/// - sends `CallCanceled` to every connection of the users whose calls were
///   canceled;
/// - re-broadcasts the contact status of the session's user and of those users;
/// - if a LiveKit client is configured, removes this connection from the
///   LiveKit room, and deletes that room when the room has no participants
///   left. LiveKit errors are only logged.
///
/// Errors from `leave_room` and `update_user_contacts` are returned.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact status must be re-broadcast once the room is left.
    let mut contacts_to_update = HashSet::default();

    // Pulled out of `left_room` so they outlive the scope that holds it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Broadcast before `live_kit_room` is taken below, so the sent room
        // still carries it.
        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        // NOTE(review): assumes `room.participants` excludes the leaving
        // connection, so "empty" means nobody is left — confirm in `leave_room`.
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    {
        // Scoped so this pool guard is released before `update_user_contacts`,
        // which acquires the connection pool itself.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // The participant identity passed to LiveKit is this connection id as a string.
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
2016
2017fn project_left(project: &db::LeftProject, session: &Session) {
2018 for connection_id in &project.connection_ids {
2019 if project.host_user_id == session.user_id {
2020 session
2021 .peer
2022 .send(
2023 *connection_id,
2024 proto::UnshareProject {
2025 project_id: project.id.to_proto(),
2026 },
2027 )
2028 .trace_err();
2029 } else {
2030 session
2031 .peer
2032 .send(
2033 *connection_id,
2034 proto::RemoveProjectCollaborator {
2035 project_id: project.id.to_proto(),
2036 peer_id: session.connection_id.0,
2037 },
2038 )
2039 .trace_err();
2040 }
2041 }
2042
2043 session
2044 .peer
2045 .send(
2046 session.connection_id,
2047 proto::UnshareProject {
2048 project_id: project.id.to_proto(),
2049 },
2050 )
2051 .trace_err();
2052}
2053
2054pub trait ResultExt {
2055 type Ok;
2056
2057 fn trace_err(self) -> Option<Self::Ok>;
2058}
2059
2060impl<T, E> ResultExt for Result<T, E>
2061where
2062 E: std::fmt::Debug,
2063{
2064 type Ok = T;
2065
2066 fn trace_err(self) -> Option<T> {
2067 match self {
2068 Ok(value) => Some(value),
2069 Err(error) => {
2070 tracing::error!("{:?}", error);
2071 None
2072 }
2073 }
2074 }
2075}