1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits for a disconnected user to come back before it
/// removes them from their room and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How long `Server::start` waits before refreshing the rooms reported by
/// `stale_room_ids`.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
90#[derive(Clone)]
91struct Session {
92 user_id: UserId,
93 connection_id: ConnectionId,
94 db: Arc<tokio::sync::Mutex<DbHandle>>,
95 peer: Arc<Peer>,
96 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
97 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
98}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
140pub struct Server {
141 peer: Arc<Peer>,
142 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
143 app_state: Arc<AppState>,
144 executor: Executor,
145 handlers: HashMap<TypeId, MessageHandler>,
146 teardown: watch::Sender<()>,
147}
148
149pub(crate) struct ConnectionPoolGuard<'a> {
150 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
151 _not_send: PhantomData<Rc<()>>,
152}
153
154#[derive(Serialize)]
155pub struct ServerSnapshot<'a> {
156 peer: &'a Peer,
157 #[serde(serialize_with = "serialize_deref")]
158 connection_pool: ConnectionPoolGuard<'a>,
159}
160
161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
162where
163 S: Serializer,
164 T: Deref<Target = U>,
165 U: Serialize,
166{
167 Serialize::serialize(value.deref(), serializer)
168}
169
170impl Server {
171 pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
172 let mut server = Self {
173 peer: Peer::new(),
174 app_state,
175 executor,
176 connection_pool: Default::default(),
177 handlers: Default::default(),
178 teardown: watch::channel(()).0,
179 };
180
181 server
182 .add_request_handler(ping)
183 .add_request_handler(create_room)
184 .add_request_handler(join_room)
185 .add_message_handler(leave_room)
186 .add_request_handler(call)
187 .add_request_handler(cancel_call)
188 .add_message_handler(decline_call)
189 .add_request_handler(update_participant_location)
190 .add_request_handler(share_project)
191 .add_message_handler(unshare_project)
192 .add_request_handler(join_project)
193 .add_message_handler(leave_project)
194 .add_request_handler(update_project)
195 .add_request_handler(update_worktree)
196 .add_message_handler(start_language_server)
197 .add_message_handler(update_language_server)
198 .add_message_handler(update_diagnostic_summary)
199 .add_request_handler(forward_project_request::<proto::GetHover>)
200 .add_request_handler(forward_project_request::<proto::GetDefinition>)
201 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
202 .add_request_handler(forward_project_request::<proto::GetReferences>)
203 .add_request_handler(forward_project_request::<proto::SearchProject>)
204 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
205 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
206 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
209 .add_request_handler(forward_project_request::<proto::GetCompletions>)
210 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
211 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
212 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
213 .add_request_handler(forward_project_request::<proto::PrepareRename>)
214 .add_request_handler(forward_project_request::<proto::PerformRename>)
215 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
216 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
217 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
218 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
221 .add_message_handler(create_buffer_for_peer)
222 .add_request_handler(update_buffer)
223 .add_message_handler(update_buffer_file)
224 .add_message_handler(buffer_reloaded)
225 .add_message_handler(buffer_saved)
226 .add_request_handler(save_buffer)
227 .add_request_handler(get_users)
228 .add_request_handler(fuzzy_search_users)
229 .add_request_handler(request_contact)
230 .add_request_handler(remove_contact)
231 .add_request_handler(respond_to_contact_request)
232 .add_request_handler(follow)
233 .add_message_handler(unfollow)
234 .add_message_handler(update_followers)
235 .add_message_handler(update_diff_base)
236 .add_request_handler(get_private_user_info);
237
238 Arc::new(server)
239 }
240
241 pub fn start(&self) {
242 let db = self.app_state.db.clone();
243 let peer = self.peer.clone();
244 let pool = self.connection_pool.clone();
245 let live_kit_client = self.app_state.live_kit_client.clone();
246 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
247 self.executor.spawn_detached(async move {
248 db.delete_stale_projects().await.trace_err();
249 timeout.await;
250 if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
251 for room_id in room_ids {
252 let mut contacts_to_update = HashSet::default();
253 let mut canceled_calls_to_user_ids = Vec::new();
254 let mut live_kit_room = String::new();
255 let mut delete_live_kit_room = false;
256
257 if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
258 room_updated(&refreshed_room.room, &peer);
259 contacts_to_update
260 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
261 contacts_to_update
262 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
263 canceled_calls_to_user_ids =
264 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
265 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
266 delete_live_kit_room = refreshed_room.room.participants.is_empty();
267 }
268
269 {
270 let pool = pool.lock();
271 for canceled_user_id in canceled_calls_to_user_ids {
272 for connection_id in pool.user_connection_ids(canceled_user_id) {
273 peer.send(
274 connection_id,
275 proto::CallCanceled {
276 room_id: room_id.to_proto(),
277 },
278 )
279 .trace_err();
280 }
281 }
282 }
283
284 for user_id in contacts_to_update {
285 let busy = db.is_user_busy(user_id).await.trace_err();
286 let contacts = db.get_contacts(user_id).await.trace_err();
287 if let Some((busy, contacts)) = busy.zip(contacts) {
288 let pool = pool.lock();
289 let updated_contact = contact_for_user(user_id, false, busy, &pool);
290 for contact in contacts {
291 if let db::Contact::Accepted {
292 user_id: contact_user_id,
293 ..
294 } = contact
295 {
296 for contact_conn_id in pool.user_connection_ids(contact_user_id)
297 {
298 peer.send(
299 contact_conn_id,
300 proto::UpdateContacts {
301 contacts: vec![updated_contact.clone()],
302 remove_contacts: Default::default(),
303 incoming_requests: Default::default(),
304 remove_incoming_requests: Default::default(),
305 outgoing_requests: Default::default(),
306 remove_outgoing_requests: Default::default(),
307 },
308 )
309 .trace_err();
310 }
311 }
312 }
313 }
314 }
315
316 if let Some(live_kit) = live_kit_client.as_ref() {
317 if delete_live_kit_room {
318 live_kit.delete_room(live_kit_room).await.trace_err();
319 }
320 }
321 }
322 }
323 });
324 }
325
326 pub fn teardown(&self) {
327 self.peer.reset();
328 self.connection_pool.lock().reset();
329 let _ = self.teardown.send(());
330 }
331
332 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
333 where
334 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
335 Fut: 'static + Send + Future<Output = Result<()>>,
336 M: EnvelopedMessage,
337 {
338 let prev_handler = self.handlers.insert(
339 TypeId::of::<M>(),
340 Box::new(move |envelope, session| {
341 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
342 let span = info_span!(
343 "handle message",
344 payload_type = envelope.payload_type_name()
345 );
346 span.in_scope(|| {
347 tracing::info!(
348 payload_type = envelope.payload_type_name(),
349 "message received"
350 );
351 });
352 let future = (handler)(*envelope, session);
353 async move {
354 if let Err(error) = future.await {
355 tracing::error!(%error, "error handling message");
356 }
357 }
358 .instrument(span)
359 .boxed()
360 }),
361 );
362 if prev_handler.is_some() {
363 panic!("registered a handler for the same message twice");
364 }
365 self
366 }
367
368 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
369 where
370 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
371 Fut: 'static + Send + Future<Output = Result<()>>,
372 M: EnvelopedMessage,
373 {
374 self.add_handler(move |envelope, session| handler(envelope.payload, session));
375 self
376 }
377
378 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
379 where
380 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
381 Fut: Send + Future<Output = Result<()>>,
382 M: RequestMessage,
383 {
384 let handler = Arc::new(handler);
385 self.add_handler(move |envelope, session| {
386 let receipt = envelope.receipt();
387 let handler = handler.clone();
388 async move {
389 let peer = session.peer.clone();
390 let responded = Arc::new(AtomicBool::default());
391 let response = Response {
392 peer: peer.clone(),
393 responded: responded.clone(),
394 receipt,
395 };
396 match (handler)(envelope.payload, response, session).await {
397 Ok(()) => {
398 if responded.load(std::sync::atomic::Ordering::SeqCst) {
399 Ok(())
400 } else {
401 Err(anyhow!("handler did not send a response"))?
402 }
403 }
404 Err(error) => {
405 peer.respond_with_error(
406 receipt,
407 proto::Error {
408 message: error.to_string(),
409 },
410 )?;
411 Err(error)
412 }
413 }
414 }
415 })
416 }
417
418 pub fn handle_connection(
419 self: &Arc<Self>,
420 connection: Connection,
421 address: String,
422 user: User,
423 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
424 executor: Executor,
425 ) -> impl Future<Output = Result<()>> {
426 let this = self.clone();
427 let user_id = user.id;
428 let login = user.github_login;
429 let span = info_span!("handle connection", %user_id, %login, %address);
430 let mut teardown = self.teardown.subscribe();
431 async move {
432 let (connection_id, handle_io, mut incoming_rx) = this
433 .peer
434 .add_connection(connection, {
435 let executor = executor.clone();
436 move |duration| executor.sleep(duration)
437 });
438
439 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
440 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
441 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
442
443 if let Some(send_connection_id) = send_connection_id.take() {
444 let _ = send_connection_id.send(connection_id);
445 }
446
447 if !user.connected_once {
448 this.peer.send(connection_id, proto::ShowContacts {})?;
449 this.app_state.db.set_user_connected_once(user_id, true).await?;
450 }
451
452 let (contacts, invite_code) = future::try_join(
453 this.app_state.db.get_contacts(user_id),
454 this.app_state.db.get_invite_code_for_user(user_id)
455 ).await?;
456
457 {
458 let mut pool = this.connection_pool.lock();
459 pool.add_connection(connection_id, user_id, user.admin);
460 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
461
462 if let Some((code, count)) = invite_code {
463 this.peer.send(connection_id, proto::UpdateInviteInfo {
464 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
465 count: count as u32,
466 })?;
467 }
468 }
469
470 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
471 this.peer.send(connection_id, incoming_call)?;
472 }
473
474 let session = Session {
475 user_id,
476 connection_id,
477 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
478 peer: this.peer.clone(),
479 connection_pool: this.connection_pool.clone(),
480 live_kit_client: this.app_state.live_kit_client.clone()
481 };
482 update_user_contacts(user_id, &session).await?;
483
484 let handle_io = handle_io.fuse();
485 futures::pin_mut!(handle_io);
486
487 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
488 // This prevents deadlocks when e.g., client A performs a request to client B and
489 // client B performs a request to client A. If both clients stop processing further
490 // messages until their respective request completes, they won't have a chance to
491 // respond to the other client's request and cause a deadlock.
492 //
493 // This arrangement ensures we will attempt to process earlier messages first, but fall
494 // back to processing messages arrived later in the spirit of making progress.
495 let mut foreground_message_handlers = FuturesUnordered::new();
496 loop {
497 let next_message = incoming_rx.next().fuse();
498 futures::pin_mut!(next_message);
499 futures::select_biased! {
500 _ = teardown.changed().fuse() => return Ok(()),
501 result = handle_io => {
502 if let Err(error) = result {
503 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
504 }
505 break;
506 }
507 _ = foreground_message_handlers.next() => {}
508 message = next_message => {
509 if let Some(message) = message {
510 let type_name = message.payload_type_name();
511 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
512 let span_enter = span.enter();
513 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
514 let is_background = message.is_background();
515 let handle_message = (handler)(message, session.clone());
516 drop(span_enter);
517
518 let handle_message = handle_message.instrument(span);
519 if is_background {
520 executor.spawn_detached(handle_message);
521 } else {
522 foreground_message_handlers.push(handle_message);
523 }
524 } else {
525 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
526 }
527 } else {
528 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
529 break;
530 }
531 }
532 }
533 }
534
535 drop(foreground_message_handlers);
536 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
537 if let Err(error) = sign_out(session, teardown, executor).await {
538 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
539 }
540
541 Ok(())
542 }.instrument(span)
543 }
544
545 pub async fn invite_code_redeemed(
546 self: &Arc<Self>,
547 inviter_id: UserId,
548 invitee_id: UserId,
549 ) -> Result<()> {
550 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
551 if let Some(code) = &user.invite_code {
552 let pool = self.connection_pool.lock();
553 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
554 for connection_id in pool.user_connection_ids(inviter_id) {
555 self.peer.send(
556 connection_id,
557 proto::UpdateContacts {
558 contacts: vec![invitee_contact.clone()],
559 ..Default::default()
560 },
561 )?;
562 self.peer.send(
563 connection_id,
564 proto::UpdateInviteInfo {
565 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
566 count: user.invite_count as u32,
567 },
568 )?;
569 }
570 }
571 }
572 Ok(())
573 }
574
575 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
576 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
577 if let Some(invite_code) = &user.invite_code {
578 let pool = self.connection_pool.lock();
579 for connection_id in pool.user_connection_ids(user_id) {
580 self.peer.send(
581 connection_id,
582 proto::UpdateInviteInfo {
583 url: format!(
584 "{}{}",
585 self.app_state.config.invite_link_prefix, invite_code
586 ),
587 count: user.invite_count as u32,
588 },
589 )?;
590 }
591 }
592 }
593 Ok(())
594 }
595
596 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
597 ServerSnapshot {
598 connection_pool: ConnectionPoolGuard {
599 guard: self.connection_pool.lock(),
600 _not_send: PhantomData,
601 },
602 peer: &self.peer,
603 }
604 }
605}
606
607impl<'a> Deref for ConnectionPoolGuard<'a> {
608 type Target = ConnectionPool;
609
610 fn deref(&self) -> &Self::Target {
611 &*self.guard
612 }
613}
614
615impl<'a> DerefMut for ConnectionPoolGuard<'a> {
616 fn deref_mut(&mut self) -> &mut Self::Target {
617 &mut *self.guard
618 }
619}
620
621impl<'a> Drop for ConnectionPoolGuard<'a> {
622 fn drop(&mut self) {
623 #[cfg(test)]
624 self.check_invariants();
625 }
626}
627
628fn broadcast<F>(
629 sender_id: ConnectionId,
630 receiver_ids: impl IntoIterator<Item = ConnectionId>,
631 mut f: F,
632) where
633 F: FnMut(ConnectionId) -> anyhow::Result<()>,
634{
635 for receiver_id in receiver_ids {
636 if receiver_id != sender_id {
637 f(receiver_id).trace_err();
638 }
639 }
640}
641
642lazy_static! {
643 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
644}
645
646pub struct ProtocolVersion(u32);
647
648impl Header for ProtocolVersion {
649 fn name() -> &'static HeaderName {
650 &ZED_PROTOCOL_VERSION
651 }
652
653 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
654 where
655 Self: Sized,
656 I: Iterator<Item = &'i axum::http::HeaderValue>,
657 {
658 let version = values
659 .next()
660 .ok_or_else(axum::headers::Error::invalid)?
661 .to_str()
662 .map_err(|_| axum::headers::Error::invalid())?
663 .parse()
664 .map_err(|_| axum::headers::Error::invalid())?;
665 Ok(Self(version))
666 }
667
668 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
669 values.extend([self.0.to_string().parse().unwrap()]);
670 }
671}
672
673pub fn routes(server: Arc<Server>) -> Router<Body> {
674 Router::new()
675 .route("/rpc", get(handle_websocket_request))
676 .layer(
677 ServiceBuilder::new()
678 .layer(Extension(server.app_state.clone()))
679 .layer(middleware::from_fn(auth::validate_header)),
680 )
681 .route("/metrics", get(handle_metrics))
682 .layer(Extension(server))
683}
684
685pub async fn handle_websocket_request(
686 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
687 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
688 Extension(server): Extension<Arc<Server>>,
689 Extension(user): Extension<User>,
690 ws: WebSocketUpgrade,
691) -> axum::response::Response {
692 if protocol_version != rpc::PROTOCOL_VERSION {
693 return (
694 StatusCode::UPGRADE_REQUIRED,
695 "client must be upgraded".to_string(),
696 )
697 .into_response();
698 }
699 let socket_address = socket_address.to_string();
700 ws.on_upgrade(move |socket| {
701 use util::ResultExt;
702 let socket = socket
703 .map_ok(to_tungstenite_message)
704 .err_into()
705 .with(|message| async move { Ok(to_axum_message(message)) });
706 let connection = Connection::new(Box::pin(socket));
707 async move {
708 server
709 .handle_connection(connection, socket_address, user, None, Executor::Production)
710 .await
711 .log_err();
712 }
713 })
714}
715
716pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
717 let connections = server
718 .connection_pool
719 .lock()
720 .connections()
721 .filter(|connection| !connection.admin)
722 .count();
723
724 METRIC_CONNECTIONS.set(connections as _);
725
726 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
727 METRIC_SHARED_PROJECTS.set(shared_projects as _);
728
729 let encoder = prometheus::TextEncoder::new();
730 let metric_families = prometheus::gather();
731 let encoded_metrics = encoder
732 .encode_to_string(&metric_families)
733 .map_err(|err| anyhow!("{}", err))?;
734 Ok(encoded_metrics)
735}
736
737#[instrument(err, skip(executor))]
738async fn sign_out(
739 session: Session,
740 mut teardown: watch::Receiver<()>,
741 executor: Executor,
742) -> Result<()> {
743 session.peer.disconnect(session.connection_id);
744 session
745 .connection_pool()
746 .await
747 .remove_connection(session.connection_id)?;
748
749 if let Some(mut left_projects) = session
750 .db()
751 .await
752 .connection_lost(session.connection_id)
753 .await
754 .trace_err()
755 {
756 for left_project in mem::take(&mut *left_projects) {
757 project_left(&left_project, &session);
758 }
759 }
760
761 futures::select_biased! {
762 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
763 leave_room_for_session(&session).await.trace_err();
764
765 if !session
766 .connection_pool()
767 .await
768 .is_user_online(session.user_id)
769 {
770 let db = session.db().await;
771 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
772 room_updated(&room, &session.peer);
773 }
774 }
775 update_user_contacts(session.user_id, &session).await?;
776 }
777 _ = teardown.changed().fuse() => {}
778 }
779
780 Ok(())
781}
782
783async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
784 response.send(proto::Ack {})?;
785 Ok(())
786}
787
788async fn create_room(
789 _request: proto::CreateRoom,
790 response: Response<proto::CreateRoom>,
791 session: Session,
792) -> Result<()> {
793 let live_kit_room = nanoid::nanoid!(30);
794 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
795 if let Some(_) = live_kit
796 .create_room(live_kit_room.clone())
797 .await
798 .trace_err()
799 {
800 if let Some(token) = live_kit
801 .room_token(&live_kit_room, &session.connection_id.to_string())
802 .trace_err()
803 {
804 Some(proto::LiveKitConnectionInfo {
805 server_url: live_kit.url().into(),
806 token,
807 })
808 } else {
809 None
810 }
811 } else {
812 None
813 }
814 } else {
815 None
816 };
817
818 {
819 let room = session
820 .db()
821 .await
822 .create_room(session.user_id, session.connection_id, &live_kit_room)
823 .await?;
824
825 response.send(proto::CreateRoomResponse {
826 room: Some(room.clone()),
827 live_kit_connection_info,
828 })?;
829 }
830
831 update_user_contacts(session.user_id, &session).await?;
832 Ok(())
833}
834
835async fn join_room(
836 request: proto::JoinRoom,
837 response: Response<proto::JoinRoom>,
838 session: Session,
839) -> Result<()> {
840 let room_id = RoomId::from_proto(request.id);
841 let room = {
842 let room = session
843 .db()
844 .await
845 .join_room(room_id, session.user_id, session.connection_id)
846 .await?;
847 room_updated(&room, &session.peer);
848 room.clone()
849 };
850
851 for connection_id in session
852 .connection_pool()
853 .await
854 .user_connection_ids(session.user_id)
855 {
856 session
857 .peer
858 .send(
859 connection_id,
860 proto::CallCanceled {
861 room_id: room_id.to_proto(),
862 },
863 )
864 .trace_err();
865 }
866
867 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
868 if let Some(token) = live_kit
869 .room_token(&room.live_kit_room, &session.connection_id.to_string())
870 .trace_err()
871 {
872 Some(proto::LiveKitConnectionInfo {
873 server_url: live_kit.url().into(),
874 token,
875 })
876 } else {
877 None
878 }
879 } else {
880 None
881 };
882
883 response.send(proto::JoinRoomResponse {
884 room: Some(room),
885 live_kit_connection_info,
886 })?;
887
888 update_user_contacts(session.user_id, &session).await?;
889 Ok(())
890}
891
892async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
893 leave_room_for_session(&session).await
894}
895
896async fn call(
897 request: proto::Call,
898 response: Response<proto::Call>,
899 session: Session,
900) -> Result<()> {
901 let room_id = RoomId::from_proto(request.room_id);
902 let calling_user_id = session.user_id;
903 let calling_connection_id = session.connection_id;
904 let called_user_id = UserId::from_proto(request.called_user_id);
905 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
906 if !session
907 .db()
908 .await
909 .has_contact(calling_user_id, called_user_id)
910 .await?
911 {
912 return Err(anyhow!("cannot call a user who isn't a contact"))?;
913 }
914
915 let incoming_call = {
916 let (room, incoming_call) = &mut *session
917 .db()
918 .await
919 .call(
920 room_id,
921 calling_user_id,
922 calling_connection_id,
923 called_user_id,
924 initial_project_id,
925 )
926 .await?;
927 room_updated(&room, &session.peer);
928 mem::take(incoming_call)
929 };
930 update_user_contacts(called_user_id, &session).await?;
931
932 let mut calls = session
933 .connection_pool()
934 .await
935 .user_connection_ids(called_user_id)
936 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
937 .collect::<FuturesUnordered<_>>();
938
939 while let Some(call_response) = calls.next().await {
940 match call_response.as_ref() {
941 Ok(_) => {
942 response.send(proto::Ack {})?;
943 return Ok(());
944 }
945 Err(_) => {
946 call_response.trace_err();
947 }
948 }
949 }
950
951 {
952 let room = session
953 .db()
954 .await
955 .call_failed(room_id, called_user_id)
956 .await?;
957 room_updated(&room, &session.peer);
958 }
959 update_user_contacts(called_user_id, &session).await?;
960
961 Err(anyhow!("failed to ring user"))?
962}
963
964async fn cancel_call(
965 request: proto::CancelCall,
966 response: Response<proto::CancelCall>,
967 session: Session,
968) -> Result<()> {
969 let called_user_id = UserId::from_proto(request.called_user_id);
970 let room_id = RoomId::from_proto(request.room_id);
971 {
972 let room = session
973 .db()
974 .await
975 .cancel_call(Some(room_id), session.connection_id, called_user_id)
976 .await?;
977 room_updated(&room, &session.peer);
978 }
979
980 for connection_id in session
981 .connection_pool()
982 .await
983 .user_connection_ids(called_user_id)
984 {
985 session
986 .peer
987 .send(
988 connection_id,
989 proto::CallCanceled {
990 room_id: room_id.to_proto(),
991 },
992 )
993 .trace_err();
994 }
995 response.send(proto::Ack {})?;
996
997 update_user_contacts(called_user_id, &session).await?;
998 Ok(())
999}
1000
1001async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1002 let room_id = RoomId::from_proto(message.room_id);
1003 {
1004 let room = session
1005 .db()
1006 .await
1007 .decline_call(Some(room_id), session.user_id)
1008 .await?;
1009 room_updated(&room, &session.peer);
1010 }
1011
1012 for connection_id in session
1013 .connection_pool()
1014 .await
1015 .user_connection_ids(session.user_id)
1016 {
1017 session
1018 .peer
1019 .send(
1020 connection_id,
1021 proto::CallCanceled {
1022 room_id: room_id.to_proto(),
1023 },
1024 )
1025 .trace_err();
1026 }
1027 update_user_contacts(session.user_id, &session).await?;
1028 Ok(())
1029}
1030
1031async fn update_participant_location(
1032 request: proto::UpdateParticipantLocation,
1033 response: Response<proto::UpdateParticipantLocation>,
1034 session: Session,
1035) -> Result<()> {
1036 let room_id = RoomId::from_proto(request.room_id);
1037 let location = request
1038 .location
1039 .ok_or_else(|| anyhow!("invalid location"))?;
1040 let room = session
1041 .db()
1042 .await
1043 .update_room_participant_location(room_id, session.connection_id, location)
1044 .await?;
1045 room_updated(&room, &session.peer);
1046 response.send(proto::Ack {})?;
1047 Ok(())
1048}
1049
1050async fn share_project(
1051 request: proto::ShareProject,
1052 response: Response<proto::ShareProject>,
1053 session: Session,
1054) -> Result<()> {
1055 let (project_id, room) = &*session
1056 .db()
1057 .await
1058 .share_project(
1059 RoomId::from_proto(request.room_id),
1060 session.connection_id,
1061 &request.worktrees,
1062 )
1063 .await?;
1064 response.send(proto::ShareProjectResponse {
1065 project_id: project_id.to_proto(),
1066 })?;
1067 room_updated(&room, &session.peer);
1068
1069 Ok(())
1070}
1071
1072async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1073 let project_id = ProjectId::from_proto(message.project_id);
1074
1075 let (room, guest_connection_ids) = &*session
1076 .db()
1077 .await
1078 .unshare_project(project_id, session.connection_id)
1079 .await?;
1080
1081 broadcast(
1082 session.connection_id,
1083 guest_connection_ids.iter().copied(),
1084 |conn_id| session.peer.send(conn_id, message.clone()),
1085 );
1086 room_updated(&room, &session.peer);
1087
1088 Ok(())
1089}
1090
1091async fn join_project(
1092 request: proto::JoinProject,
1093 response: Response<proto::JoinProject>,
1094 session: Session,
1095) -> Result<()> {
1096 let project_id = ProjectId::from_proto(request.project_id);
1097 let guest_user_id = session.user_id;
1098
1099 tracing::info!(%project_id, "join project");
1100
1101 let (project, replica_id) = &mut *session
1102 .db()
1103 .await
1104 .join_project(project_id, session.connection_id)
1105 .await?;
1106
1107 let collaborators = project
1108 .collaborators
1109 .iter()
1110 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1111 .map(|collaborator| proto::Collaborator {
1112 peer_id: collaborator.connection_id as u32,
1113 replica_id: collaborator.replica_id.0 as u32,
1114 user_id: collaborator.user_id.to_proto(),
1115 })
1116 .collect::<Vec<_>>();
1117 let worktrees = project
1118 .worktrees
1119 .iter()
1120 .map(|(id, worktree)| proto::WorktreeMetadata {
1121 id: *id,
1122 root_name: worktree.root_name.clone(),
1123 visible: worktree.visible,
1124 abs_path: worktree.abs_path.clone(),
1125 })
1126 .collect::<Vec<_>>();
1127
1128 for collaborator in &collaborators {
1129 session
1130 .peer
1131 .send(
1132 ConnectionId(collaborator.peer_id),
1133 proto::AddProjectCollaborator {
1134 project_id: project_id.to_proto(),
1135 collaborator: Some(proto::Collaborator {
1136 peer_id: session.connection_id.0,
1137 replica_id: replica_id.0 as u32,
1138 user_id: guest_user_id.to_proto(),
1139 }),
1140 },
1141 )
1142 .trace_err();
1143 }
1144
1145 // First, we send the metadata associated with each worktree.
1146 response.send(proto::JoinProjectResponse {
1147 worktrees: worktrees.clone(),
1148 replica_id: replica_id.0 as u32,
1149 collaborators: collaborators.clone(),
1150 language_servers: project.language_servers.clone(),
1151 })?;
1152
1153 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1154 #[cfg(any(test, feature = "test-support"))]
1155 const MAX_CHUNK_SIZE: usize = 2;
1156 #[cfg(not(any(test, feature = "test-support")))]
1157 const MAX_CHUNK_SIZE: usize = 256;
1158
1159 // Stream this worktree's entries.
1160 let message = proto::UpdateWorktree {
1161 project_id: project_id.to_proto(),
1162 worktree_id,
1163 abs_path: worktree.abs_path.clone(),
1164 root_name: worktree.root_name,
1165 updated_entries: worktree.entries,
1166 removed_entries: Default::default(),
1167 scan_id: worktree.scan_id,
1168 is_last_update: worktree.is_complete,
1169 };
1170 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1171 session.peer.send(session.connection_id, update.clone())?;
1172 }
1173
1174 // Stream this worktree's diagnostics.
1175 for summary in worktree.diagnostic_summaries {
1176 session.peer.send(
1177 session.connection_id,
1178 proto::UpdateDiagnosticSummary {
1179 project_id: project_id.to_proto(),
1180 worktree_id: worktree.id,
1181 summary: Some(summary),
1182 },
1183 )?;
1184 }
1185 }
1186
1187 for language_server in &project.language_servers {
1188 session.peer.send(
1189 session.connection_id,
1190 proto::UpdateLanguageServer {
1191 project_id: project_id.to_proto(),
1192 language_server_id: language_server.id,
1193 variant: Some(
1194 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1195 proto::LspDiskBasedDiagnosticsUpdated {},
1196 ),
1197 ),
1198 },
1199 )?;
1200 }
1201
1202 Ok(())
1203}
1204
1205async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1206 let sender_id = session.connection_id;
1207 let project_id = ProjectId::from_proto(request.project_id);
1208
1209 let project = session
1210 .db()
1211 .await
1212 .leave_project(project_id, sender_id)
1213 .await?;
1214 tracing::info!(
1215 %project_id,
1216 host_user_id = %project.host_user_id,
1217 host_connection_id = %project.host_connection_id,
1218 "leave project"
1219 );
1220 project_left(&project, &session);
1221
1222 Ok(())
1223}
1224
1225async fn update_project(
1226 request: proto::UpdateProject,
1227 response: Response<proto::UpdateProject>,
1228 session: Session,
1229) -> Result<()> {
1230 let project_id = ProjectId::from_proto(request.project_id);
1231 let (room, guest_connection_ids) = &*session
1232 .db()
1233 .await
1234 .update_project(project_id, session.connection_id, &request.worktrees)
1235 .await?;
1236 broadcast(
1237 session.connection_id,
1238 guest_connection_ids.iter().copied(),
1239 |connection_id| {
1240 session
1241 .peer
1242 .forward_send(session.connection_id, connection_id, request.clone())
1243 },
1244 );
1245 room_updated(&room, &session.peer);
1246 response.send(proto::Ack {})?;
1247
1248 Ok(())
1249}
1250
1251async fn update_worktree(
1252 request: proto::UpdateWorktree,
1253 response: Response<proto::UpdateWorktree>,
1254 session: Session,
1255) -> Result<()> {
1256 let guest_connection_ids = session
1257 .db()
1258 .await
1259 .update_worktree(&request, session.connection_id)
1260 .await?;
1261
1262 broadcast(
1263 session.connection_id,
1264 guest_connection_ids.iter().copied(),
1265 |connection_id| {
1266 session
1267 .peer
1268 .forward_send(session.connection_id, connection_id, request.clone())
1269 },
1270 );
1271 response.send(proto::Ack {})?;
1272 Ok(())
1273}
1274
1275async fn update_diagnostic_summary(
1276 message: proto::UpdateDiagnosticSummary,
1277 session: Session,
1278) -> Result<()> {
1279 let guest_connection_ids = session
1280 .db()
1281 .await
1282 .update_diagnostic_summary(&message, session.connection_id)
1283 .await?;
1284
1285 broadcast(
1286 session.connection_id,
1287 guest_connection_ids.iter().copied(),
1288 |connection_id| {
1289 session
1290 .peer
1291 .forward_send(session.connection_id, connection_id, message.clone())
1292 },
1293 );
1294
1295 Ok(())
1296}
1297
1298async fn start_language_server(
1299 request: proto::StartLanguageServer,
1300 session: Session,
1301) -> Result<()> {
1302 let guest_connection_ids = session
1303 .db()
1304 .await
1305 .start_language_server(&request, session.connection_id)
1306 .await?;
1307
1308 broadcast(
1309 session.connection_id,
1310 guest_connection_ids.iter().copied(),
1311 |connection_id| {
1312 session
1313 .peer
1314 .forward_send(session.connection_id, connection_id, request.clone())
1315 },
1316 );
1317 Ok(())
1318}
1319
1320async fn update_language_server(
1321 request: proto::UpdateLanguageServer,
1322 session: Session,
1323) -> Result<()> {
1324 let project_id = ProjectId::from_proto(request.project_id);
1325 let project_connection_ids = session
1326 .db()
1327 .await
1328 .project_connection_ids(project_id, session.connection_id)
1329 .await?;
1330 broadcast(
1331 session.connection_id,
1332 project_connection_ids.iter().copied(),
1333 |connection_id| {
1334 session
1335 .peer
1336 .forward_send(session.connection_id, connection_id, request.clone())
1337 },
1338 );
1339 Ok(())
1340}
1341
1342async fn forward_project_request<T>(
1343 request: T,
1344 response: Response<T>,
1345 session: Session,
1346) -> Result<()>
1347where
1348 T: EntityMessage + RequestMessage,
1349{
1350 let project_id = ProjectId::from_proto(request.remote_entity_id());
1351 let host_connection_id = {
1352 let collaborators = session
1353 .db()
1354 .await
1355 .project_collaborators(project_id, session.connection_id)
1356 .await?;
1357 ConnectionId(
1358 collaborators
1359 .iter()
1360 .find(|collaborator| collaborator.is_host)
1361 .ok_or_else(|| anyhow!("host not found"))?
1362 .connection_id as u32,
1363 )
1364 };
1365
1366 let payload = session
1367 .peer
1368 .forward_request(session.connection_id, host_connection_id, request)
1369 .await?;
1370
1371 response.send(payload)?;
1372 Ok(())
1373}
1374
1375async fn save_buffer(
1376 request: proto::SaveBuffer,
1377 response: Response<proto::SaveBuffer>,
1378 session: Session,
1379) -> Result<()> {
1380 let project_id = ProjectId::from_proto(request.project_id);
1381 let host_connection_id = {
1382 let collaborators = session
1383 .db()
1384 .await
1385 .project_collaborators(project_id, session.connection_id)
1386 .await?;
1387 let host = collaborators
1388 .iter()
1389 .find(|collaborator| collaborator.is_host)
1390 .ok_or_else(|| anyhow!("host not found"))?;
1391 ConnectionId(host.connection_id as u32)
1392 };
1393 let response_payload = session
1394 .peer
1395 .forward_request(session.connection_id, host_connection_id, request.clone())
1396 .await?;
1397
1398 let mut collaborators = session
1399 .db()
1400 .await
1401 .project_collaborators(project_id, session.connection_id)
1402 .await?;
1403 collaborators
1404 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1405 let project_connection_ids = collaborators
1406 .iter()
1407 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1408 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1409 session
1410 .peer
1411 .forward_send(host_connection_id, conn_id, response_payload.clone())
1412 });
1413 response.send(response_payload)?;
1414 Ok(())
1415}
1416
1417async fn create_buffer_for_peer(
1418 request: proto::CreateBufferForPeer,
1419 session: Session,
1420) -> Result<()> {
1421 session.peer.forward_send(
1422 session.connection_id,
1423 ConnectionId(request.peer_id),
1424 request,
1425 )?;
1426 Ok(())
1427}
1428
1429async fn update_buffer(
1430 request: proto::UpdateBuffer,
1431 response: Response<proto::UpdateBuffer>,
1432 session: Session,
1433) -> Result<()> {
1434 let project_id = ProjectId::from_proto(request.project_id);
1435 let project_connection_ids = session
1436 .db()
1437 .await
1438 .project_connection_ids(project_id, session.connection_id)
1439 .await?;
1440
1441 broadcast(
1442 session.connection_id,
1443 project_connection_ids.iter().copied(),
1444 |connection_id| {
1445 session
1446 .peer
1447 .forward_send(session.connection_id, connection_id, request.clone())
1448 },
1449 );
1450 response.send(proto::Ack {})?;
1451 Ok(())
1452}
1453
1454async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1455 let project_id = ProjectId::from_proto(request.project_id);
1456 let project_connection_ids = session
1457 .db()
1458 .await
1459 .project_connection_ids(project_id, session.connection_id)
1460 .await?;
1461
1462 broadcast(
1463 session.connection_id,
1464 project_connection_ids.iter().copied(),
1465 |connection_id| {
1466 session
1467 .peer
1468 .forward_send(session.connection_id, connection_id, request.clone())
1469 },
1470 );
1471 Ok(())
1472}
1473
1474async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1475 let project_id = ProjectId::from_proto(request.project_id);
1476 let project_connection_ids = session
1477 .db()
1478 .await
1479 .project_connection_ids(project_id, session.connection_id)
1480 .await?;
1481 broadcast(
1482 session.connection_id,
1483 project_connection_ids.iter().copied(),
1484 |connection_id| {
1485 session
1486 .peer
1487 .forward_send(session.connection_id, connection_id, request.clone())
1488 },
1489 );
1490 Ok(())
1491}
1492
1493async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1494 let project_id = ProjectId::from_proto(request.project_id);
1495 let project_connection_ids = session
1496 .db()
1497 .await
1498 .project_connection_ids(project_id, session.connection_id)
1499 .await?;
1500 broadcast(
1501 session.connection_id,
1502 project_connection_ids.iter().copied(),
1503 |connection_id| {
1504 session
1505 .peer
1506 .forward_send(session.connection_id, connection_id, request.clone())
1507 },
1508 );
1509 Ok(())
1510}
1511
1512async fn follow(
1513 request: proto::Follow,
1514 response: Response<proto::Follow>,
1515 session: Session,
1516) -> Result<()> {
1517 let project_id = ProjectId::from_proto(request.project_id);
1518 let leader_id = ConnectionId(request.leader_id);
1519 let follower_id = session.connection_id;
1520 {
1521 let project_connection_ids = session
1522 .db()
1523 .await
1524 .project_connection_ids(project_id, session.connection_id)
1525 .await?;
1526
1527 if !project_connection_ids.contains(&leader_id) {
1528 Err(anyhow!("no such peer"))?;
1529 }
1530 }
1531
1532 let mut response_payload = session
1533 .peer
1534 .forward_request(session.connection_id, leader_id, request)
1535 .await?;
1536 response_payload
1537 .views
1538 .retain(|view| view.leader_id != Some(follower_id.0));
1539 response.send(response_payload)?;
1540 Ok(())
1541}
1542
1543async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1544 let project_id = ProjectId::from_proto(request.project_id);
1545 let leader_id = ConnectionId(request.leader_id);
1546 let project_connection_ids = session
1547 .db()
1548 .await
1549 .project_connection_ids(project_id, session.connection_id)
1550 .await?;
1551 if !project_connection_ids.contains(&leader_id) {
1552 Err(anyhow!("no such peer"))?;
1553 }
1554 session
1555 .peer
1556 .forward_send(session.connection_id, leader_id, request)?;
1557 Ok(())
1558}
1559
1560async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1561 let project_id = ProjectId::from_proto(request.project_id);
1562 let project_connection_ids = session
1563 .db
1564 .lock()
1565 .await
1566 .project_connection_ids(project_id, session.connection_id)
1567 .await?;
1568
1569 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1570 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1571 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1572 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1573 });
1574 for follower_id in &request.follower_ids {
1575 let follower_id = ConnectionId(*follower_id);
1576 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1577 session
1578 .peer
1579 .forward_send(session.connection_id, follower_id, request.clone())?;
1580 }
1581 }
1582 Ok(())
1583}
1584
1585async fn get_users(
1586 request: proto::GetUsers,
1587 response: Response<proto::GetUsers>,
1588 session: Session,
1589) -> Result<()> {
1590 let user_ids = request
1591 .user_ids
1592 .into_iter()
1593 .map(UserId::from_proto)
1594 .collect();
1595 let users = session
1596 .db()
1597 .await
1598 .get_users_by_ids(user_ids)
1599 .await?
1600 .into_iter()
1601 .map(|user| proto::User {
1602 id: user.id.to_proto(),
1603 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1604 github_login: user.github_login,
1605 })
1606 .collect();
1607 response.send(proto::UsersResponse { users })?;
1608 Ok(())
1609}
1610
1611async fn fuzzy_search_users(
1612 request: proto::FuzzySearchUsers,
1613 response: Response<proto::FuzzySearchUsers>,
1614 session: Session,
1615) -> Result<()> {
1616 let query = request.query;
1617 let users = match query.len() {
1618 0 => vec![],
1619 1 | 2 => session
1620 .db()
1621 .await
1622 .get_user_by_github_account(&query, None)
1623 .await?
1624 .into_iter()
1625 .collect(),
1626 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1627 };
1628 let users = users
1629 .into_iter()
1630 .filter(|user| user.id != session.user_id)
1631 .map(|user| proto::User {
1632 id: user.id.to_proto(),
1633 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1634 github_login: user.github_login,
1635 })
1636 .collect();
1637 response.send(proto::UsersResponse { users })?;
1638 Ok(())
1639}
1640
1641async fn request_contact(
1642 request: proto::RequestContact,
1643 response: Response<proto::RequestContact>,
1644 session: Session,
1645) -> Result<()> {
1646 let requester_id = session.user_id;
1647 let responder_id = UserId::from_proto(request.responder_id);
1648 if requester_id == responder_id {
1649 return Err(anyhow!("cannot add yourself as a contact"))?;
1650 }
1651
1652 session
1653 .db()
1654 .await
1655 .send_contact_request(requester_id, responder_id)
1656 .await?;
1657
1658 // Update outgoing contact requests of requester
1659 let mut update = proto::UpdateContacts::default();
1660 update.outgoing_requests.push(responder_id.to_proto());
1661 for connection_id in session
1662 .connection_pool()
1663 .await
1664 .user_connection_ids(requester_id)
1665 {
1666 session.peer.send(connection_id, update.clone())?;
1667 }
1668
1669 // Update incoming contact requests of responder
1670 let mut update = proto::UpdateContacts::default();
1671 update
1672 .incoming_requests
1673 .push(proto::IncomingContactRequest {
1674 requester_id: requester_id.to_proto(),
1675 should_notify: true,
1676 });
1677 for connection_id in session
1678 .connection_pool()
1679 .await
1680 .user_connection_ids(responder_id)
1681 {
1682 session.peer.send(connection_id, update.clone())?;
1683 }
1684
1685 response.send(proto::Ack {})?;
1686 Ok(())
1687}
1688
1689async fn respond_to_contact_request(
1690 request: proto::RespondToContactRequest,
1691 response: Response<proto::RespondToContactRequest>,
1692 session: Session,
1693) -> Result<()> {
1694 let responder_id = session.user_id;
1695 let requester_id = UserId::from_proto(request.requester_id);
1696 let db = session.db().await;
1697 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1698 db.dismiss_contact_notification(responder_id, requester_id)
1699 .await?;
1700 } else {
1701 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1702
1703 db.respond_to_contact_request(responder_id, requester_id, accept)
1704 .await?;
1705 let requester_busy = db.is_user_busy(requester_id).await?;
1706 let responder_busy = db.is_user_busy(responder_id).await?;
1707
1708 let pool = session.connection_pool().await;
1709 // Update responder with new contact
1710 let mut update = proto::UpdateContacts::default();
1711 if accept {
1712 update
1713 .contacts
1714 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1715 }
1716 update
1717 .remove_incoming_requests
1718 .push(requester_id.to_proto());
1719 for connection_id in pool.user_connection_ids(responder_id) {
1720 session.peer.send(connection_id, update.clone())?;
1721 }
1722
1723 // Update requester with new contact
1724 let mut update = proto::UpdateContacts::default();
1725 if accept {
1726 update
1727 .contacts
1728 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1729 }
1730 update
1731 .remove_outgoing_requests
1732 .push(responder_id.to_proto());
1733 for connection_id in pool.user_connection_ids(requester_id) {
1734 session.peer.send(connection_id, update.clone())?;
1735 }
1736 }
1737
1738 response.send(proto::Ack {})?;
1739 Ok(())
1740}
1741
1742async fn remove_contact(
1743 request: proto::RemoveContact,
1744 response: Response<proto::RemoveContact>,
1745 session: Session,
1746) -> Result<()> {
1747 let requester_id = session.user_id;
1748 let responder_id = UserId::from_proto(request.user_id);
1749 let db = session.db().await;
1750 db.remove_contact(requester_id, responder_id).await?;
1751
1752 let pool = session.connection_pool().await;
1753 // Update outgoing contact requests of requester
1754 let mut update = proto::UpdateContacts::default();
1755 update
1756 .remove_outgoing_requests
1757 .push(responder_id.to_proto());
1758 for connection_id in pool.user_connection_ids(requester_id) {
1759 session.peer.send(connection_id, update.clone())?;
1760 }
1761
1762 // Update incoming contact requests of responder
1763 let mut update = proto::UpdateContacts::default();
1764 update
1765 .remove_incoming_requests
1766 .push(requester_id.to_proto());
1767 for connection_id in pool.user_connection_ids(responder_id) {
1768 session.peer.send(connection_id, update.clone())?;
1769 }
1770
1771 response.send(proto::Ack {})?;
1772 Ok(())
1773}
1774
1775async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1776 let project_id = ProjectId::from_proto(request.project_id);
1777 let project_connection_ids = session
1778 .db()
1779 .await
1780 .project_connection_ids(project_id, session.connection_id)
1781 .await?;
1782 broadcast(
1783 session.connection_id,
1784 project_connection_ids.iter().copied(),
1785 |connection_id| {
1786 session
1787 .peer
1788 .forward_send(session.connection_id, connection_id, request.clone())
1789 },
1790 );
1791 Ok(())
1792}
1793
1794async fn get_private_user_info(
1795 _request: proto::GetPrivateUserInfo,
1796 response: Response<proto::GetPrivateUserInfo>,
1797 session: Session,
1798) -> Result<()> {
1799 let metrics_id = session
1800 .db()
1801 .await
1802 .get_user_metrics_id(session.user_id)
1803 .await?;
1804 let user = session
1805 .db()
1806 .await
1807 .get_user_by_id(session.user_id)
1808 .await?
1809 .ok_or_else(|| anyhow!("user not found"))?;
1810 response.send(proto::GetPrivateUserInfoResponse {
1811 metrics_id,
1812 staff: user.admin,
1813 })?;
1814 Ok(())
1815}
1816
1817fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1818 match message {
1819 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1820 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1821 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1822 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1823 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1824 code: frame.code.into(),
1825 reason: frame.reason,
1826 })),
1827 }
1828}
1829
1830fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1831 match message {
1832 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1833 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1834 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1835 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1836 AxumMessage::Close(frame) => {
1837 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1838 code: frame.code.into(),
1839 reason: frame.reason,
1840 }))
1841 }
1842 }
1843}
1844
1845fn build_initial_contacts_update(
1846 contacts: Vec<db::Contact>,
1847 pool: &ConnectionPool,
1848) -> proto::UpdateContacts {
1849 let mut update = proto::UpdateContacts::default();
1850
1851 for contact in contacts {
1852 match contact {
1853 db::Contact::Accepted {
1854 user_id,
1855 should_notify,
1856 busy,
1857 } => {
1858 update
1859 .contacts
1860 .push(contact_for_user(user_id, should_notify, busy, &pool));
1861 }
1862 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1863 db::Contact::Incoming {
1864 user_id,
1865 should_notify,
1866 } => update
1867 .incoming_requests
1868 .push(proto::IncomingContactRequest {
1869 requester_id: user_id.to_proto(),
1870 should_notify,
1871 }),
1872 }
1873 }
1874
1875 update
1876}
1877
1878fn contact_for_user(
1879 user_id: UserId,
1880 should_notify: bool,
1881 busy: bool,
1882 pool: &ConnectionPool,
1883) -> proto::Contact {
1884 proto::Contact {
1885 user_id: user_id.to_proto(),
1886 online: pool.is_user_online(user_id),
1887 busy,
1888 should_notify,
1889 }
1890}
1891
1892fn room_updated(room: &proto::Room, peer: &Peer) {
1893 for participant in &room.participants {
1894 peer.send(
1895 ConnectionId(participant.peer_id),
1896 proto::RoomUpdated {
1897 room: Some(room.clone()),
1898 },
1899 )
1900 .trace_err();
1901 }
1902}
1903
1904async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1905 let db = session.db().await;
1906 let contacts = db.get_contacts(user_id).await?;
1907 let busy = db.is_user_busy(user_id).await?;
1908
1909 let pool = session.connection_pool().await;
1910 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1911 for contact in contacts {
1912 if let db::Contact::Accepted {
1913 user_id: contact_user_id,
1914 ..
1915 } = contact
1916 {
1917 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1918 session
1919 .peer
1920 .send(
1921 contact_conn_id,
1922 proto::UpdateContacts {
1923 contacts: vec![updated_contact.clone()],
1924 remove_contacts: Default::default(),
1925 incoming_requests: Default::default(),
1926 remove_incoming_requests: Default::default(),
1927 outgoing_requests: Default::default(),
1928 remove_outgoing_requests: Default::default(),
1929 },
1930 )
1931 .trace_err();
1932 }
1933 }
1934 }
1935 Ok(())
1936}
1937
1938async fn leave_room_for_session(session: &Session) -> Result<()> {
1939 let mut contacts_to_update = HashSet::default();
1940
1941 let room_id;
1942 let canceled_calls_to_user_ids;
1943 let live_kit_room;
1944 let delete_live_kit_room;
1945 {
1946 let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1947 contacts_to_update.insert(session.user_id);
1948
1949 for project in left_room.left_projects.values() {
1950 project_left(project, session);
1951 }
1952
1953 room_updated(&left_room.room, &session.peer);
1954 room_id = RoomId::from_proto(left_room.room.id);
1955 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1956 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1957 delete_live_kit_room = left_room.room.participants.is_empty();
1958 }
1959
1960 {
1961 let pool = session.connection_pool().await;
1962 for canceled_user_id in canceled_calls_to_user_ids {
1963 for connection_id in pool.user_connection_ids(canceled_user_id) {
1964 session
1965 .peer
1966 .send(
1967 connection_id,
1968 proto::CallCanceled {
1969 room_id: room_id.to_proto(),
1970 },
1971 )
1972 .trace_err();
1973 }
1974 contacts_to_update.insert(canceled_user_id);
1975 }
1976 }
1977
1978 for contact_user_id in contacts_to_update {
1979 update_user_contacts(contact_user_id, &session).await?;
1980 }
1981
1982 if let Some(live_kit) = session.live_kit_client.as_ref() {
1983 live_kit
1984 .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1985 .await
1986 .trace_err();
1987
1988 if delete_live_kit_room {
1989 live_kit.delete_room(live_kit_room).await.trace_err();
1990 }
1991 }
1992
1993 Ok(())
1994}
1995
1996fn project_left(project: &db::LeftProject, session: &Session) {
1997 for connection_id in &project.connection_ids {
1998 if project.host_user_id == session.user_id {
1999 session
2000 .peer
2001 .send(
2002 *connection_id,
2003 proto::UnshareProject {
2004 project_id: project.id.to_proto(),
2005 },
2006 )
2007 .trace_err();
2008 } else {
2009 session
2010 .peer
2011 .send(
2012 *connection_id,
2013 proto::RemoveProjectCollaborator {
2014 project_id: project.id.to_proto(),
2015 peer_id: session.connection_id.0,
2016 },
2017 )
2018 .trace_err();
2019 }
2020 }
2021
2022 session
2023 .peer
2024 .send(
2025 session.connection_id,
2026 proto::UnshareProject {
2027 project_id: project.id.to_proto(),
2028 },
2029 )
2030 .trace_err();
2031}
2032
2033pub trait ResultExt {
2034 type Ok;
2035
2036 fn trace_err(self) -> Option<Self::Ok>;
2037}
2038
2039impl<T, E> ResultExt for Result<T, E>
2040where
2041 E: std::fmt::Debug,
2042{
2043 type Ok = T;
2044
2045 fn trace_err(self) -> Option<T> {
2046 match self {
2047 Ok(value) => Some(value),
2048 Err(error) => {
2049 tracing::error!("{:?}", error);
2050 None
2051 }
2052 }
2053 }
2054}