1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::watch;
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// Grace period a disconnected connection is given before `sign_out` removes the user
/// from their room and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Delay after `Server::start` before the detached stale-room cleanup runs.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(20_000);
62
63lazy_static! {
64 static ref METRIC_CONNECTIONS: IntGauge =
65 register_int_gauge!("connections", "number of connections").unwrap();
66 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
67 "shared_projects",
68 "number of open projects with one or more guests"
69 )
70 .unwrap();
71}
72
73type MessageHandler =
74 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
75
76struct Response<R> {
77 peer: Arc<Peer>,
78 receipt: Receipt<R>,
79 responded: Arc<AtomicBool>,
80}
81
82impl<R: RequestMessage> Response<R> {
83 fn send(self, payload: R::Response) -> Result<()> {
84 self.responded.store(true, SeqCst);
85 self.peer.respond(self.receipt, payload)?;
86 Ok(())
87 }
88}
89
/// Per-connection context passed (cloned) to every message handler.
#[derive(Clone)]
struct Session {
    /// User that owns this connection.
    user_id: UserId,
    /// The connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; `handle_connection` creates a fresh mutex
    /// per connection, and handlers lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; locked through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
99
100impl Session {
101 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
102 #[cfg(test)]
103 tokio::task::yield_now().await;
104 let guard = self.db.lock().await;
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 guard
108 }
109
110 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 let guard = self.connection_pool.lock();
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
130struct DbHandle(Arc<Database>);
131
132impl Deref for DbHandle {
133 type Target = Database;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// Collaboration RPC server: owns the peer, the shared connection pool and the table of
/// registered message handlers.
pub struct Server {
    /// Peer used for all RPC traffic.
    peer: Arc<Peer>,
    /// Live connections; cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `Server::teardown`; every connection loop subscribes to it.
    teardown: watch::Sender<()>,
}
148
/// Lock guard for the connection pool.
///
/// The `Rc` marker makes the guard `!Send`, so a `Send` future cannot hold it across
/// an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
153
/// Serializable view of the server's peer and connection pool.
///
/// The pool stays locked for as long as the snapshot is alive. Fields serialize in
/// declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
160
161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
162where
163 S: Serializer,
164 T: Deref<Target = U>,
165 U: Serialize,
166{
167 Serialize::serialize(value.deref(), serializer)
168}
169
170impl Server {
171 pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
172 let mut server = Self {
173 peer: Peer::new(),
174 app_state,
175 executor,
176 connection_pool: Default::default(),
177 handlers: Default::default(),
178 teardown: watch::channel(()).0,
179 };
180
181 server
182 .add_request_handler(ping)
183 .add_request_handler(create_room)
184 .add_request_handler(join_room)
185 .add_message_handler(leave_room)
186 .add_request_handler(call)
187 .add_request_handler(cancel_call)
188 .add_message_handler(decline_call)
189 .add_request_handler(update_participant_location)
190 .add_request_handler(share_project)
191 .add_message_handler(unshare_project)
192 .add_request_handler(join_project)
193 .add_message_handler(leave_project)
194 .add_request_handler(update_project)
195 .add_request_handler(update_worktree)
196 .add_message_handler(start_language_server)
197 .add_message_handler(update_language_server)
198 .add_message_handler(update_diagnostic_summary)
199 .add_request_handler(forward_project_request::<proto::GetHover>)
200 .add_request_handler(forward_project_request::<proto::GetDefinition>)
201 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
202 .add_request_handler(forward_project_request::<proto::GetReferences>)
203 .add_request_handler(forward_project_request::<proto::SearchProject>)
204 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
205 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
206 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
207 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
208 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
209 .add_request_handler(forward_project_request::<proto::GetCompletions>)
210 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
211 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
212 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
213 .add_request_handler(forward_project_request::<proto::PrepareRename>)
214 .add_request_handler(forward_project_request::<proto::PerformRename>)
215 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
216 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
217 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
218 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
219 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
220 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
221 .add_message_handler(create_buffer_for_peer)
222 .add_request_handler(update_buffer)
223 .add_message_handler(update_buffer_file)
224 .add_message_handler(buffer_reloaded)
225 .add_message_handler(buffer_saved)
226 .add_request_handler(save_buffer)
227 .add_request_handler(get_users)
228 .add_request_handler(fuzzy_search_users)
229 .add_request_handler(request_contact)
230 .add_request_handler(remove_contact)
231 .add_request_handler(respond_to_contact_request)
232 .add_request_handler(follow)
233 .add_message_handler(unfollow)
234 .add_message_handler(update_followers)
235 .add_message_handler(update_diff_base)
236 .add_request_handler(get_private_user_info);
237
238 Arc::new(server)
239 }
240
241 pub async fn start(&self) -> Result<()> {
242 self.app_state.db.delete_stale_projects().await?;
243 let db = self.app_state.db.clone();
244 let peer = self.peer.clone();
245 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
246 let pool = self.connection_pool.clone();
247 let live_kit_client = self.app_state.live_kit_client.clone();
248 self.executor.spawn_detached(async move {
249 timeout.await;
250 if let Some(room_ids) = db.stale_room_ids().await.trace_err() {
251 for room_id in room_ids {
252 let mut contacts_to_update = HashSet::default();
253 let mut canceled_calls_to_user_ids = Vec::new();
254 let mut live_kit_room = String::new();
255 let mut delete_live_kit_room = false;
256
257 if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
258 room_updated(&refreshed_room.room, &peer);
259 contacts_to_update
260 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
261 contacts_to_update
262 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
263 canceled_calls_to_user_ids =
264 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
265 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
266 delete_live_kit_room = refreshed_room.room.participants.is_empty();
267 }
268
269 {
270 let pool = pool.lock();
271 for canceled_user_id in canceled_calls_to_user_ids {
272 for connection_id in pool.user_connection_ids(canceled_user_id) {
273 peer.send(
274 connection_id,
275 proto::CallCanceled {
276 room_id: room_id.to_proto(),
277 },
278 )
279 .trace_err();
280 }
281 }
282 }
283
284 for user_id in contacts_to_update {
285 let busy = db.is_user_busy(user_id).await.trace_err();
286 let contacts = db.get_contacts(user_id).await.trace_err();
287 if let Some((busy, contacts)) = busy.zip(contacts) {
288 let pool = pool.lock();
289 let updated_contact = contact_for_user(user_id, false, busy, &pool);
290 for contact in contacts {
291 if let db::Contact::Accepted {
292 user_id: contact_user_id,
293 ..
294 } = contact
295 {
296 for contact_conn_id in pool.user_connection_ids(contact_user_id)
297 {
298 peer.send(
299 contact_conn_id,
300 proto::UpdateContacts {
301 contacts: vec![updated_contact.clone()],
302 remove_contacts: Default::default(),
303 incoming_requests: Default::default(),
304 remove_incoming_requests: Default::default(),
305 outgoing_requests: Default::default(),
306 remove_outgoing_requests: Default::default(),
307 },
308 )
309 .trace_err();
310 }
311 }
312 }
313 }
314 }
315
316 if let Some(live_kit) = live_kit_client.as_ref() {
317 if delete_live_kit_room {
318 live_kit.delete_room(live_kit_room).await.trace_err();
319 }
320 }
321 }
322 }
323 });
324 Ok(())
325 }
326
327 pub fn teardown(&self) {
328 self.peer.reset();
329 self.connection_pool.lock().reset();
330 let _ = self.teardown.send(());
331 }
332
333 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
334 where
335 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
336 Fut: 'static + Send + Future<Output = Result<()>>,
337 M: EnvelopedMessage,
338 {
339 let prev_handler = self.handlers.insert(
340 TypeId::of::<M>(),
341 Box::new(move |envelope, session| {
342 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
343 let span = info_span!(
344 "handle message",
345 payload_type = envelope.payload_type_name()
346 );
347 span.in_scope(|| {
348 tracing::info!(
349 payload_type = envelope.payload_type_name(),
350 "message received"
351 );
352 });
353 let future = (handler)(*envelope, session);
354 async move {
355 if let Err(error) = future.await {
356 tracing::error!(%error, "error handling message");
357 }
358 }
359 .instrument(span)
360 .boxed()
361 }),
362 );
363 if prev_handler.is_some() {
364 panic!("registered a handler for the same message twice");
365 }
366 self
367 }
368
369 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
370 where
371 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
372 Fut: 'static + Send + Future<Output = Result<()>>,
373 M: EnvelopedMessage,
374 {
375 self.add_handler(move |envelope, session| handler(envelope.payload, session));
376 self
377 }
378
379 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
380 where
381 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
382 Fut: Send + Future<Output = Result<()>>,
383 M: RequestMessage,
384 {
385 let handler = Arc::new(handler);
386 self.add_handler(move |envelope, session| {
387 let receipt = envelope.receipt();
388 let handler = handler.clone();
389 async move {
390 let peer = session.peer.clone();
391 let responded = Arc::new(AtomicBool::default());
392 let response = Response {
393 peer: peer.clone(),
394 responded: responded.clone(),
395 receipt,
396 };
397 match (handler)(envelope.payload, response, session).await {
398 Ok(()) => {
399 if responded.load(std::sync::atomic::Ordering::SeqCst) {
400 Ok(())
401 } else {
402 Err(anyhow!("handler did not send a response"))?
403 }
404 }
405 Err(error) => {
406 peer.respond_with_error(
407 receipt,
408 proto::Error {
409 message: error.to_string(),
410 },
411 )?;
412 Err(error)
413 }
414 }
415 }
416 })
417 }
418
419 pub fn handle_connection(
420 self: &Arc<Self>,
421 connection: Connection,
422 address: String,
423 user: User,
424 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
425 executor: Executor,
426 ) -> impl Future<Output = Result<()>> {
427 let this = self.clone();
428 let user_id = user.id;
429 let login = user.github_login;
430 let span = info_span!("handle connection", %user_id, %login, %address);
431 let mut teardown = self.teardown.subscribe();
432 async move {
433 let (connection_id, handle_io, mut incoming_rx) = this
434 .peer
435 .add_connection(connection, {
436 let executor = executor.clone();
437 move |duration| executor.sleep(duration)
438 });
439
440 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
441 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
442 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
443
444 if let Some(send_connection_id) = send_connection_id.take() {
445 let _ = send_connection_id.send(connection_id);
446 }
447
448 if !user.connected_once {
449 this.peer.send(connection_id, proto::ShowContacts {})?;
450 this.app_state.db.set_user_connected_once(user_id, true).await?;
451 }
452
453 let (contacts, invite_code) = future::try_join(
454 this.app_state.db.get_contacts(user_id),
455 this.app_state.db.get_invite_code_for_user(user_id)
456 ).await?;
457
458 {
459 let mut pool = this.connection_pool.lock();
460 pool.add_connection(connection_id, user_id, user.admin);
461 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
462
463 if let Some((code, count)) = invite_code {
464 this.peer.send(connection_id, proto::UpdateInviteInfo {
465 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
466 count: count as u32,
467 })?;
468 }
469 }
470
471 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
472 this.peer.send(connection_id, incoming_call)?;
473 }
474
475 let session = Session {
476 user_id,
477 connection_id,
478 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
479 peer: this.peer.clone(),
480 connection_pool: this.connection_pool.clone(),
481 live_kit_client: this.app_state.live_kit_client.clone()
482 };
483 update_user_contacts(user_id, &session).await?;
484
485 let handle_io = handle_io.fuse();
486 futures::pin_mut!(handle_io);
487
488 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
489 // This prevents deadlocks when e.g., client A performs a request to client B and
490 // client B performs a request to client A. If both clients stop processing further
491 // messages until their respective request completes, they won't have a chance to
492 // respond to the other client's request and cause a deadlock.
493 //
494 // This arrangement ensures we will attempt to process earlier messages first, but fall
495 // back to processing messages arrived later in the spirit of making progress.
496 let mut foreground_message_handlers = FuturesUnordered::new();
497 loop {
498 let next_message = incoming_rx.next().fuse();
499 futures::pin_mut!(next_message);
500 futures::select_biased! {
501 _ = teardown.changed().fuse() => return Ok(()),
502 result = handle_io => {
503 if let Err(error) = result {
504 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
505 }
506 break;
507 }
508 _ = foreground_message_handlers.next() => {}
509 message = next_message => {
510 if let Some(message) = message {
511 let type_name = message.payload_type_name();
512 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
513 let span_enter = span.enter();
514 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
515 let is_background = message.is_background();
516 let handle_message = (handler)(message, session.clone());
517 drop(span_enter);
518
519 let handle_message = handle_message.instrument(span);
520 if is_background {
521 executor.spawn_detached(handle_message);
522 } else {
523 foreground_message_handlers.push(handle_message);
524 }
525 } else {
526 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
527 }
528 } else {
529 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
530 break;
531 }
532 }
533 }
534 }
535
536 drop(foreground_message_handlers);
537 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
538 if let Err(error) = sign_out(session, teardown, executor).await {
539 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
540 }
541
542 Ok(())
543 }.instrument(span)
544 }
545
546 pub async fn invite_code_redeemed(
547 self: &Arc<Self>,
548 inviter_id: UserId,
549 invitee_id: UserId,
550 ) -> Result<()> {
551 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
552 if let Some(code) = &user.invite_code {
553 let pool = self.connection_pool.lock();
554 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
555 for connection_id in pool.user_connection_ids(inviter_id) {
556 self.peer.send(
557 connection_id,
558 proto::UpdateContacts {
559 contacts: vec![invitee_contact.clone()],
560 ..Default::default()
561 },
562 )?;
563 self.peer.send(
564 connection_id,
565 proto::UpdateInviteInfo {
566 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
567 count: user.invite_count as u32,
568 },
569 )?;
570 }
571 }
572 }
573 Ok(())
574 }
575
576 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
577 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
578 if let Some(invite_code) = &user.invite_code {
579 let pool = self.connection_pool.lock();
580 for connection_id in pool.user_connection_ids(user_id) {
581 self.peer.send(
582 connection_id,
583 proto::UpdateInviteInfo {
584 url: format!(
585 "{}{}",
586 self.app_state.config.invite_link_prefix, invite_code
587 ),
588 count: user.invite_count as u32,
589 },
590 )?;
591 }
592 }
593 }
594 Ok(())
595 }
596
597 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
598 ServerSnapshot {
599 connection_pool: ConnectionPoolGuard {
600 guard: self.connection_pool.lock(),
601 _not_send: PhantomData,
602 },
603 peer: &self.peer,
604 }
605 }
606}
607
608impl<'a> Deref for ConnectionPoolGuard<'a> {
609 type Target = ConnectionPool;
610
611 fn deref(&self) -> &Self::Target {
612 &*self.guard
613 }
614}
615
616impl<'a> DerefMut for ConnectionPoolGuard<'a> {
617 fn deref_mut(&mut self) -> &mut Self::Target {
618 &mut *self.guard
619 }
620}
621
622impl<'a> Drop for ConnectionPoolGuard<'a> {
623 fn drop(&mut self) {
624 #[cfg(test)]
625 self.check_invariants();
626 }
627}
628
629fn broadcast<F>(
630 sender_id: ConnectionId,
631 receiver_ids: impl IntoIterator<Item = ConnectionId>,
632 mut f: F,
633) where
634 F: FnMut(ConnectionId) -> anyhow::Result<()>,
635{
636 for receiver_id in receiver_ids {
637 if receiver_id != sender_id {
638 f(receiver_id).trace_err();
639 }
640 }
641}
642
643lazy_static! {
644 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
645}
646
647pub struct ProtocolVersion(u32);
648
649impl Header for ProtocolVersion {
650 fn name() -> &'static HeaderName {
651 &ZED_PROTOCOL_VERSION
652 }
653
654 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
655 where
656 Self: Sized,
657 I: Iterator<Item = &'i axum::http::HeaderValue>,
658 {
659 let version = values
660 .next()
661 .ok_or_else(axum::headers::Error::invalid)?
662 .to_str()
663 .map_err(|_| axum::headers::Error::invalid())?
664 .parse()
665 .map_err(|_| axum::headers::Error::invalid())?;
666 Ok(Self(version))
667 }
668
669 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
670 values.extend([self.0.to_string().parse().unwrap()]);
671 }
672}
673
674pub fn routes(server: Arc<Server>) -> Router<Body> {
675 Router::new()
676 .route("/rpc", get(handle_websocket_request))
677 .layer(
678 ServiceBuilder::new()
679 .layer(Extension(server.app_state.clone()))
680 .layer(middleware::from_fn(auth::validate_header)),
681 )
682 .route("/metrics", get(handle_metrics))
683 .layer(Extension(server))
684}
685
686pub async fn handle_websocket_request(
687 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
688 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
689 Extension(server): Extension<Arc<Server>>,
690 Extension(user): Extension<User>,
691 ws: WebSocketUpgrade,
692) -> axum::response::Response {
693 if protocol_version != rpc::PROTOCOL_VERSION {
694 return (
695 StatusCode::UPGRADE_REQUIRED,
696 "client must be upgraded".to_string(),
697 )
698 .into_response();
699 }
700 let socket_address = socket_address.to_string();
701 ws.on_upgrade(move |socket| {
702 use util::ResultExt;
703 let socket = socket
704 .map_ok(to_tungstenite_message)
705 .err_into()
706 .with(|message| async move { Ok(to_axum_message(message)) });
707 let connection = Connection::new(Box::pin(socket));
708 async move {
709 server
710 .handle_connection(connection, socket_address, user, None, Executor::Production)
711 .await
712 .log_err();
713 }
714 })
715}
716
717pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
718 let connections = server
719 .connection_pool
720 .lock()
721 .connections()
722 .filter(|connection| !connection.admin)
723 .count();
724
725 METRIC_CONNECTIONS.set(connections as _);
726
727 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
728 METRIC_SHARED_PROJECTS.set(shared_projects as _);
729
730 let encoder = prometheus::TextEncoder::new();
731 let metric_families = prometheus::gather();
732 let encoded_metrics = encoder
733 .encode_to_string(&metric_families)
734 .map_err(|err| anyhow!("{}", err))?;
735 Ok(encoded_metrics)
736}
737
/// Cleans up after a connection's message loop has exited.
///
/// Immediate steps:
/// - disconnects the peer and removes the connection from the pool;
/// - calls `project_left` for each project the database reports the connection left.
///
/// After `RECONNECT_TIMEOUT`:
/// - removes the session from its room;
/// - declines any pending call, if `is_user_online` reports the user as offline;
/// - refreshes the user's contacts.
///
/// If the server is torn down before the timeout fires, those delayed steps are
/// skipped.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A database failure here is only logged, so the rest of the cleanup still runs.
    // NOTE(review): under edition-2021 temporary rules, the `db()` guard created in
    // this scrutinee stays locked for the whole `if let` body.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    // `select_biased!` polls branches in declaration order: the reconnect timer is
    // checked before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // Only decline the pending call if the pool no longer reports the user as
            // online (presumably: no other connection for this user remains).
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
783
784async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
785 response.send(proto::Ack {})?;
786 Ok(())
787}
788
789async fn create_room(
790 _request: proto::CreateRoom,
791 response: Response<proto::CreateRoom>,
792 session: Session,
793) -> Result<()> {
794 let live_kit_room = nanoid::nanoid!(30);
795 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
796 if let Some(_) = live_kit
797 .create_room(live_kit_room.clone())
798 .await
799 .trace_err()
800 {
801 if let Some(token) = live_kit
802 .room_token(&live_kit_room, &session.connection_id.to_string())
803 .trace_err()
804 {
805 Some(proto::LiveKitConnectionInfo {
806 server_url: live_kit.url().into(),
807 token,
808 })
809 } else {
810 None
811 }
812 } else {
813 None
814 }
815 } else {
816 None
817 };
818
819 {
820 let room = session
821 .db()
822 .await
823 .create_room(session.user_id, session.connection_id, &live_kit_room)
824 .await?;
825
826 response.send(proto::CreateRoomResponse {
827 room: Some(room.clone()),
828 live_kit_connection_info,
829 })?;
830 }
831
832 update_user_contacts(session.user_id, &session).await?;
833 Ok(())
834}
835
836async fn join_room(
837 request: proto::JoinRoom,
838 response: Response<proto::JoinRoom>,
839 session: Session,
840) -> Result<()> {
841 let room_id = RoomId::from_proto(request.id);
842 let room = {
843 let room = session
844 .db()
845 .await
846 .join_room(room_id, session.user_id, session.connection_id)
847 .await?;
848 room_updated(&room, &session.peer);
849 room.clone()
850 };
851
852 for connection_id in session
853 .connection_pool()
854 .await
855 .user_connection_ids(session.user_id)
856 {
857 session
858 .peer
859 .send(
860 connection_id,
861 proto::CallCanceled {
862 room_id: room_id.to_proto(),
863 },
864 )
865 .trace_err();
866 }
867
868 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
869 if let Some(token) = live_kit
870 .room_token(&room.live_kit_room, &session.connection_id.to_string())
871 .trace_err()
872 {
873 Some(proto::LiveKitConnectionInfo {
874 server_url: live_kit.url().into(),
875 token,
876 })
877 } else {
878 None
879 }
880 } else {
881 None
882 };
883
884 response.send(proto::JoinRoomResponse {
885 room: Some(room),
886 live_kit_connection_info,
887 })?;
888
889 update_user_contacts(session.user_id, &session).await?;
890 Ok(())
891}
892
893async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
894 leave_room_for_session(&session).await
895}
896
897async fn call(
898 request: proto::Call,
899 response: Response<proto::Call>,
900 session: Session,
901) -> Result<()> {
902 let room_id = RoomId::from_proto(request.room_id);
903 let calling_user_id = session.user_id;
904 let calling_connection_id = session.connection_id;
905 let called_user_id = UserId::from_proto(request.called_user_id);
906 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
907 if !session
908 .db()
909 .await
910 .has_contact(calling_user_id, called_user_id)
911 .await?
912 {
913 return Err(anyhow!("cannot call a user who isn't a contact"))?;
914 }
915
916 let incoming_call = {
917 let (room, incoming_call) = &mut *session
918 .db()
919 .await
920 .call(
921 room_id,
922 calling_user_id,
923 calling_connection_id,
924 called_user_id,
925 initial_project_id,
926 )
927 .await?;
928 room_updated(&room, &session.peer);
929 mem::take(incoming_call)
930 };
931 update_user_contacts(called_user_id, &session).await?;
932
933 let mut calls = session
934 .connection_pool()
935 .await
936 .user_connection_ids(called_user_id)
937 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
938 .collect::<FuturesUnordered<_>>();
939
940 while let Some(call_response) = calls.next().await {
941 match call_response.as_ref() {
942 Ok(_) => {
943 response.send(proto::Ack {})?;
944 return Ok(());
945 }
946 Err(_) => {
947 call_response.trace_err();
948 }
949 }
950 }
951
952 {
953 let room = session
954 .db()
955 .await
956 .call_failed(room_id, called_user_id)
957 .await?;
958 room_updated(&room, &session.peer);
959 }
960 update_user_contacts(called_user_id, &session).await?;
961
962 Err(anyhow!("failed to ring user"))?
963}
964
965async fn cancel_call(
966 request: proto::CancelCall,
967 response: Response<proto::CancelCall>,
968 session: Session,
969) -> Result<()> {
970 let called_user_id = UserId::from_proto(request.called_user_id);
971 let room_id = RoomId::from_proto(request.room_id);
972 {
973 let room = session
974 .db()
975 .await
976 .cancel_call(Some(room_id), session.connection_id, called_user_id)
977 .await?;
978 room_updated(&room, &session.peer);
979 }
980
981 for connection_id in session
982 .connection_pool()
983 .await
984 .user_connection_ids(called_user_id)
985 {
986 session
987 .peer
988 .send(
989 connection_id,
990 proto::CallCanceled {
991 room_id: room_id.to_proto(),
992 },
993 )
994 .trace_err();
995 }
996 response.send(proto::Ack {})?;
997
998 update_user_contacts(called_user_id, &session).await?;
999 Ok(())
1000}
1001
1002async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1003 let room_id = RoomId::from_proto(message.room_id);
1004 {
1005 let room = session
1006 .db()
1007 .await
1008 .decline_call(Some(room_id), session.user_id)
1009 .await?;
1010 room_updated(&room, &session.peer);
1011 }
1012
1013 for connection_id in session
1014 .connection_pool()
1015 .await
1016 .user_connection_ids(session.user_id)
1017 {
1018 session
1019 .peer
1020 .send(
1021 connection_id,
1022 proto::CallCanceled {
1023 room_id: room_id.to_proto(),
1024 },
1025 )
1026 .trace_err();
1027 }
1028 update_user_contacts(session.user_id, &session).await?;
1029 Ok(())
1030}
1031
1032async fn update_participant_location(
1033 request: proto::UpdateParticipantLocation,
1034 response: Response<proto::UpdateParticipantLocation>,
1035 session: Session,
1036) -> Result<()> {
1037 let room_id = RoomId::from_proto(request.room_id);
1038 let location = request
1039 .location
1040 .ok_or_else(|| anyhow!("invalid location"))?;
1041 let room = session
1042 .db()
1043 .await
1044 .update_room_participant_location(room_id, session.connection_id, location)
1045 .await?;
1046 room_updated(&room, &session.peer);
1047 response.send(proto::Ack {})?;
1048 Ok(())
1049}
1050
1051async fn share_project(
1052 request: proto::ShareProject,
1053 response: Response<proto::ShareProject>,
1054 session: Session,
1055) -> Result<()> {
1056 let (project_id, room) = &*session
1057 .db()
1058 .await
1059 .share_project(
1060 RoomId::from_proto(request.room_id),
1061 session.connection_id,
1062 &request.worktrees,
1063 )
1064 .await?;
1065 response.send(proto::ShareProjectResponse {
1066 project_id: project_id.to_proto(),
1067 })?;
1068 room_updated(&room, &session.peer);
1069
1070 Ok(())
1071}
1072
1073async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1074 let project_id = ProjectId::from_proto(message.project_id);
1075
1076 let (room, guest_connection_ids) = &*session
1077 .db()
1078 .await
1079 .unshare_project(project_id, session.connection_id)
1080 .await?;
1081
1082 broadcast(
1083 session.connection_id,
1084 guest_connection_ids.iter().copied(),
1085 |conn_id| session.peer.send(conn_id, message.clone()),
1086 );
1087 room_updated(&room, &session.peer);
1088
1089 Ok(())
1090}
1091
1092async fn join_project(
1093 request: proto::JoinProject,
1094 response: Response<proto::JoinProject>,
1095 session: Session,
1096) -> Result<()> {
1097 let project_id = ProjectId::from_proto(request.project_id);
1098 let guest_user_id = session.user_id;
1099
1100 tracing::info!(%project_id, "join project");
1101
1102 let (project, replica_id) = &mut *session
1103 .db()
1104 .await
1105 .join_project(project_id, session.connection_id)
1106 .await?;
1107
1108 let collaborators = project
1109 .collaborators
1110 .iter()
1111 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1112 .map(|collaborator| proto::Collaborator {
1113 peer_id: collaborator.connection_id as u32,
1114 replica_id: collaborator.replica_id.0 as u32,
1115 user_id: collaborator.user_id.to_proto(),
1116 })
1117 .collect::<Vec<_>>();
1118 let worktrees = project
1119 .worktrees
1120 .iter()
1121 .map(|(id, worktree)| proto::WorktreeMetadata {
1122 id: *id,
1123 root_name: worktree.root_name.clone(),
1124 visible: worktree.visible,
1125 abs_path: worktree.abs_path.clone(),
1126 })
1127 .collect::<Vec<_>>();
1128
1129 for collaborator in &collaborators {
1130 session
1131 .peer
1132 .send(
1133 ConnectionId(collaborator.peer_id),
1134 proto::AddProjectCollaborator {
1135 project_id: project_id.to_proto(),
1136 collaborator: Some(proto::Collaborator {
1137 peer_id: session.connection_id.0,
1138 replica_id: replica_id.0 as u32,
1139 user_id: guest_user_id.to_proto(),
1140 }),
1141 },
1142 )
1143 .trace_err();
1144 }
1145
1146 // First, we send the metadata associated with each worktree.
1147 response.send(proto::JoinProjectResponse {
1148 worktrees: worktrees.clone(),
1149 replica_id: replica_id.0 as u32,
1150 collaborators: collaborators.clone(),
1151 language_servers: project.language_servers.clone(),
1152 })?;
1153
1154 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1155 #[cfg(any(test, feature = "test-support"))]
1156 const MAX_CHUNK_SIZE: usize = 2;
1157 #[cfg(not(any(test, feature = "test-support")))]
1158 const MAX_CHUNK_SIZE: usize = 256;
1159
1160 // Stream this worktree's entries.
1161 let message = proto::UpdateWorktree {
1162 project_id: project_id.to_proto(),
1163 worktree_id,
1164 abs_path: worktree.abs_path.clone(),
1165 root_name: worktree.root_name,
1166 updated_entries: worktree.entries,
1167 removed_entries: Default::default(),
1168 scan_id: worktree.scan_id,
1169 is_last_update: worktree.is_complete,
1170 };
1171 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1172 session.peer.send(session.connection_id, update.clone())?;
1173 }
1174
1175 // Stream this worktree's diagnostics.
1176 for summary in worktree.diagnostic_summaries {
1177 session.peer.send(
1178 session.connection_id,
1179 proto::UpdateDiagnosticSummary {
1180 project_id: project_id.to_proto(),
1181 worktree_id: worktree.id,
1182 summary: Some(summary),
1183 },
1184 )?;
1185 }
1186 }
1187
1188 for language_server in &project.language_servers {
1189 session.peer.send(
1190 session.connection_id,
1191 proto::UpdateLanguageServer {
1192 project_id: project_id.to_proto(),
1193 language_server_id: language_server.id,
1194 variant: Some(
1195 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1196 proto::LspDiskBasedDiagnosticsUpdated {},
1197 ),
1198 ),
1199 },
1200 )?;
1201 }
1202
1203 Ok(())
1204}
1205
1206async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1207 let sender_id = session.connection_id;
1208 let project_id = ProjectId::from_proto(request.project_id);
1209
1210 let project = session
1211 .db()
1212 .await
1213 .leave_project(project_id, sender_id)
1214 .await?;
1215 tracing::info!(
1216 %project_id,
1217 host_user_id = %project.host_user_id,
1218 host_connection_id = %project.host_connection_id,
1219 "leave project"
1220 );
1221 project_left(&project, &session);
1222
1223 Ok(())
1224}
1225
1226async fn update_project(
1227 request: proto::UpdateProject,
1228 response: Response<proto::UpdateProject>,
1229 session: Session,
1230) -> Result<()> {
1231 let project_id = ProjectId::from_proto(request.project_id);
1232 let (room, guest_connection_ids) = &*session
1233 .db()
1234 .await
1235 .update_project(project_id, session.connection_id, &request.worktrees)
1236 .await?;
1237 broadcast(
1238 session.connection_id,
1239 guest_connection_ids.iter().copied(),
1240 |connection_id| {
1241 session
1242 .peer
1243 .forward_send(session.connection_id, connection_id, request.clone())
1244 },
1245 );
1246 room_updated(&room, &session.peer);
1247 response.send(proto::Ack {})?;
1248
1249 Ok(())
1250}
1251
1252async fn update_worktree(
1253 request: proto::UpdateWorktree,
1254 response: Response<proto::UpdateWorktree>,
1255 session: Session,
1256) -> Result<()> {
1257 let guest_connection_ids = session
1258 .db()
1259 .await
1260 .update_worktree(&request, session.connection_id)
1261 .await?;
1262
1263 broadcast(
1264 session.connection_id,
1265 guest_connection_ids.iter().copied(),
1266 |connection_id| {
1267 session
1268 .peer
1269 .forward_send(session.connection_id, connection_id, request.clone())
1270 },
1271 );
1272 response.send(proto::Ack {})?;
1273 Ok(())
1274}
1275
1276async fn update_diagnostic_summary(
1277 message: proto::UpdateDiagnosticSummary,
1278 session: Session,
1279) -> Result<()> {
1280 let guest_connection_ids = session
1281 .db()
1282 .await
1283 .update_diagnostic_summary(&message, session.connection_id)
1284 .await?;
1285
1286 broadcast(
1287 session.connection_id,
1288 guest_connection_ids.iter().copied(),
1289 |connection_id| {
1290 session
1291 .peer
1292 .forward_send(session.connection_id, connection_id, message.clone())
1293 },
1294 );
1295
1296 Ok(())
1297}
1298
1299async fn start_language_server(
1300 request: proto::StartLanguageServer,
1301 session: Session,
1302) -> Result<()> {
1303 let guest_connection_ids = session
1304 .db()
1305 .await
1306 .start_language_server(&request, session.connection_id)
1307 .await?;
1308
1309 broadcast(
1310 session.connection_id,
1311 guest_connection_ids.iter().copied(),
1312 |connection_id| {
1313 session
1314 .peer
1315 .forward_send(session.connection_id, connection_id, request.clone())
1316 },
1317 );
1318 Ok(())
1319}
1320
1321async fn update_language_server(
1322 request: proto::UpdateLanguageServer,
1323 session: Session,
1324) -> Result<()> {
1325 let project_id = ProjectId::from_proto(request.project_id);
1326 let project_connection_ids = session
1327 .db()
1328 .await
1329 .project_connection_ids(project_id, session.connection_id)
1330 .await?;
1331 broadcast(
1332 session.connection_id,
1333 project_connection_ids.iter().copied(),
1334 |connection_id| {
1335 session
1336 .peer
1337 .forward_send(session.connection_id, connection_id, request.clone())
1338 },
1339 );
1340 Ok(())
1341}
1342
1343async fn forward_project_request<T>(
1344 request: T,
1345 response: Response<T>,
1346 session: Session,
1347) -> Result<()>
1348where
1349 T: EntityMessage + RequestMessage,
1350{
1351 let project_id = ProjectId::from_proto(request.remote_entity_id());
1352 let host_connection_id = {
1353 let collaborators = session
1354 .db()
1355 .await
1356 .project_collaborators(project_id, session.connection_id)
1357 .await?;
1358 ConnectionId(
1359 collaborators
1360 .iter()
1361 .find(|collaborator| collaborator.is_host)
1362 .ok_or_else(|| anyhow!("host not found"))?
1363 .connection_id as u32,
1364 )
1365 };
1366
1367 let payload = session
1368 .peer
1369 .forward_request(session.connection_id, host_connection_id, request)
1370 .await?;
1371
1372 response.send(payload)?;
1373 Ok(())
1374}
1375
1376async fn save_buffer(
1377 request: proto::SaveBuffer,
1378 response: Response<proto::SaveBuffer>,
1379 session: Session,
1380) -> Result<()> {
1381 let project_id = ProjectId::from_proto(request.project_id);
1382 let host_connection_id = {
1383 let collaborators = session
1384 .db()
1385 .await
1386 .project_collaborators(project_id, session.connection_id)
1387 .await?;
1388 let host = collaborators
1389 .iter()
1390 .find(|collaborator| collaborator.is_host)
1391 .ok_or_else(|| anyhow!("host not found"))?;
1392 ConnectionId(host.connection_id as u32)
1393 };
1394 let response_payload = session
1395 .peer
1396 .forward_request(session.connection_id, host_connection_id, request.clone())
1397 .await?;
1398
1399 let mut collaborators = session
1400 .db()
1401 .await
1402 .project_collaborators(project_id, session.connection_id)
1403 .await?;
1404 collaborators
1405 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1406 let project_connection_ids = collaborators
1407 .iter()
1408 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1409 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1410 session
1411 .peer
1412 .forward_send(host_connection_id, conn_id, response_payload.clone())
1413 });
1414 response.send(response_payload)?;
1415 Ok(())
1416}
1417
1418async fn create_buffer_for_peer(
1419 request: proto::CreateBufferForPeer,
1420 session: Session,
1421) -> Result<()> {
1422 session.peer.forward_send(
1423 session.connection_id,
1424 ConnectionId(request.peer_id),
1425 request,
1426 )?;
1427 Ok(())
1428}
1429
1430async fn update_buffer(
1431 request: proto::UpdateBuffer,
1432 response: Response<proto::UpdateBuffer>,
1433 session: Session,
1434) -> Result<()> {
1435 let project_id = ProjectId::from_proto(request.project_id);
1436 let project_connection_ids = session
1437 .db()
1438 .await
1439 .project_connection_ids(project_id, session.connection_id)
1440 .await?;
1441
1442 broadcast(
1443 session.connection_id,
1444 project_connection_ids.iter().copied(),
1445 |connection_id| {
1446 session
1447 .peer
1448 .forward_send(session.connection_id, connection_id, request.clone())
1449 },
1450 );
1451 response.send(proto::Ack {})?;
1452 Ok(())
1453}
1454
1455async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1456 let project_id = ProjectId::from_proto(request.project_id);
1457 let project_connection_ids = session
1458 .db()
1459 .await
1460 .project_connection_ids(project_id, session.connection_id)
1461 .await?;
1462
1463 broadcast(
1464 session.connection_id,
1465 project_connection_ids.iter().copied(),
1466 |connection_id| {
1467 session
1468 .peer
1469 .forward_send(session.connection_id, connection_id, request.clone())
1470 },
1471 );
1472 Ok(())
1473}
1474
1475async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1476 let project_id = ProjectId::from_proto(request.project_id);
1477 let project_connection_ids = session
1478 .db()
1479 .await
1480 .project_connection_ids(project_id, session.connection_id)
1481 .await?;
1482 broadcast(
1483 session.connection_id,
1484 project_connection_ids.iter().copied(),
1485 |connection_id| {
1486 session
1487 .peer
1488 .forward_send(session.connection_id, connection_id, request.clone())
1489 },
1490 );
1491 Ok(())
1492}
1493
1494async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1495 let project_id = ProjectId::from_proto(request.project_id);
1496 let project_connection_ids = session
1497 .db()
1498 .await
1499 .project_connection_ids(project_id, session.connection_id)
1500 .await?;
1501 broadcast(
1502 session.connection_id,
1503 project_connection_ids.iter().copied(),
1504 |connection_id| {
1505 session
1506 .peer
1507 .forward_send(session.connection_id, connection_id, request.clone())
1508 },
1509 );
1510 Ok(())
1511}
1512
1513async fn follow(
1514 request: proto::Follow,
1515 response: Response<proto::Follow>,
1516 session: Session,
1517) -> Result<()> {
1518 let project_id = ProjectId::from_proto(request.project_id);
1519 let leader_id = ConnectionId(request.leader_id);
1520 let follower_id = session.connection_id;
1521 {
1522 let project_connection_ids = session
1523 .db()
1524 .await
1525 .project_connection_ids(project_id, session.connection_id)
1526 .await?;
1527
1528 if !project_connection_ids.contains(&leader_id) {
1529 Err(anyhow!("no such peer"))?;
1530 }
1531 }
1532
1533 let mut response_payload = session
1534 .peer
1535 .forward_request(session.connection_id, leader_id, request)
1536 .await?;
1537 response_payload
1538 .views
1539 .retain(|view| view.leader_id != Some(follower_id.0));
1540 response.send(response_payload)?;
1541 Ok(())
1542}
1543
1544async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1545 let project_id = ProjectId::from_proto(request.project_id);
1546 let leader_id = ConnectionId(request.leader_id);
1547 let project_connection_ids = session
1548 .db()
1549 .await
1550 .project_connection_ids(project_id, session.connection_id)
1551 .await?;
1552 if !project_connection_ids.contains(&leader_id) {
1553 Err(anyhow!("no such peer"))?;
1554 }
1555 session
1556 .peer
1557 .forward_send(session.connection_id, leader_id, request)?;
1558 Ok(())
1559}
1560
1561async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1562 let project_id = ProjectId::from_proto(request.project_id);
1563 let project_connection_ids = session
1564 .db
1565 .lock()
1566 .await
1567 .project_connection_ids(project_id, session.connection_id)
1568 .await?;
1569
1570 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1571 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1572 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1573 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1574 });
1575 for follower_id in &request.follower_ids {
1576 let follower_id = ConnectionId(*follower_id);
1577 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1578 session
1579 .peer
1580 .forward_send(session.connection_id, follower_id, request.clone())?;
1581 }
1582 }
1583 Ok(())
1584}
1585
1586async fn get_users(
1587 request: proto::GetUsers,
1588 response: Response<proto::GetUsers>,
1589 session: Session,
1590) -> Result<()> {
1591 let user_ids = request
1592 .user_ids
1593 .into_iter()
1594 .map(UserId::from_proto)
1595 .collect();
1596 let users = session
1597 .db()
1598 .await
1599 .get_users_by_ids(user_ids)
1600 .await?
1601 .into_iter()
1602 .map(|user| proto::User {
1603 id: user.id.to_proto(),
1604 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1605 github_login: user.github_login,
1606 })
1607 .collect();
1608 response.send(proto::UsersResponse { users })?;
1609 Ok(())
1610}
1611
1612async fn fuzzy_search_users(
1613 request: proto::FuzzySearchUsers,
1614 response: Response<proto::FuzzySearchUsers>,
1615 session: Session,
1616) -> Result<()> {
1617 let query = request.query;
1618 let users = match query.len() {
1619 0 => vec![],
1620 1 | 2 => session
1621 .db()
1622 .await
1623 .get_user_by_github_account(&query, None)
1624 .await?
1625 .into_iter()
1626 .collect(),
1627 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1628 };
1629 let users = users
1630 .into_iter()
1631 .filter(|user| user.id != session.user_id)
1632 .map(|user| proto::User {
1633 id: user.id.to_proto(),
1634 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1635 github_login: user.github_login,
1636 })
1637 .collect();
1638 response.send(proto::UsersResponse { users })?;
1639 Ok(())
1640}
1641
1642async fn request_contact(
1643 request: proto::RequestContact,
1644 response: Response<proto::RequestContact>,
1645 session: Session,
1646) -> Result<()> {
1647 let requester_id = session.user_id;
1648 let responder_id = UserId::from_proto(request.responder_id);
1649 if requester_id == responder_id {
1650 return Err(anyhow!("cannot add yourself as a contact"))?;
1651 }
1652
1653 session
1654 .db()
1655 .await
1656 .send_contact_request(requester_id, responder_id)
1657 .await?;
1658
1659 // Update outgoing contact requests of requester
1660 let mut update = proto::UpdateContacts::default();
1661 update.outgoing_requests.push(responder_id.to_proto());
1662 for connection_id in session
1663 .connection_pool()
1664 .await
1665 .user_connection_ids(requester_id)
1666 {
1667 session.peer.send(connection_id, update.clone())?;
1668 }
1669
1670 // Update incoming contact requests of responder
1671 let mut update = proto::UpdateContacts::default();
1672 update
1673 .incoming_requests
1674 .push(proto::IncomingContactRequest {
1675 requester_id: requester_id.to_proto(),
1676 should_notify: true,
1677 });
1678 for connection_id in session
1679 .connection_pool()
1680 .await
1681 .user_connection_ids(responder_id)
1682 {
1683 session.peer.send(connection_id, update.clone())?;
1684 }
1685
1686 response.send(proto::Ack {})?;
1687 Ok(())
1688}
1689
1690async fn respond_to_contact_request(
1691 request: proto::RespondToContactRequest,
1692 response: Response<proto::RespondToContactRequest>,
1693 session: Session,
1694) -> Result<()> {
1695 let responder_id = session.user_id;
1696 let requester_id = UserId::from_proto(request.requester_id);
1697 let db = session.db().await;
1698 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1699 db.dismiss_contact_notification(responder_id, requester_id)
1700 .await?;
1701 } else {
1702 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1703
1704 db.respond_to_contact_request(responder_id, requester_id, accept)
1705 .await?;
1706 let requester_busy = db.is_user_busy(requester_id).await?;
1707 let responder_busy = db.is_user_busy(responder_id).await?;
1708
1709 let pool = session.connection_pool().await;
1710 // Update responder with new contact
1711 let mut update = proto::UpdateContacts::default();
1712 if accept {
1713 update
1714 .contacts
1715 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1716 }
1717 update
1718 .remove_incoming_requests
1719 .push(requester_id.to_proto());
1720 for connection_id in pool.user_connection_ids(responder_id) {
1721 session.peer.send(connection_id, update.clone())?;
1722 }
1723
1724 // Update requester with new contact
1725 let mut update = proto::UpdateContacts::default();
1726 if accept {
1727 update
1728 .contacts
1729 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1730 }
1731 update
1732 .remove_outgoing_requests
1733 .push(responder_id.to_proto());
1734 for connection_id in pool.user_connection_ids(requester_id) {
1735 session.peer.send(connection_id, update.clone())?;
1736 }
1737 }
1738
1739 response.send(proto::Ack {})?;
1740 Ok(())
1741}
1742
1743async fn remove_contact(
1744 request: proto::RemoveContact,
1745 response: Response<proto::RemoveContact>,
1746 session: Session,
1747) -> Result<()> {
1748 let requester_id = session.user_id;
1749 let responder_id = UserId::from_proto(request.user_id);
1750 let db = session.db().await;
1751 db.remove_contact(requester_id, responder_id).await?;
1752
1753 let pool = session.connection_pool().await;
1754 // Update outgoing contact requests of requester
1755 let mut update = proto::UpdateContacts::default();
1756 update
1757 .remove_outgoing_requests
1758 .push(responder_id.to_proto());
1759 for connection_id in pool.user_connection_ids(requester_id) {
1760 session.peer.send(connection_id, update.clone())?;
1761 }
1762
1763 // Update incoming contact requests of responder
1764 let mut update = proto::UpdateContacts::default();
1765 update
1766 .remove_incoming_requests
1767 .push(requester_id.to_proto());
1768 for connection_id in pool.user_connection_ids(responder_id) {
1769 session.peer.send(connection_id, update.clone())?;
1770 }
1771
1772 response.send(proto::Ack {})?;
1773 Ok(())
1774}
1775
1776async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1777 let project_id = ProjectId::from_proto(request.project_id);
1778 let project_connection_ids = session
1779 .db()
1780 .await
1781 .project_connection_ids(project_id, session.connection_id)
1782 .await?;
1783 broadcast(
1784 session.connection_id,
1785 project_connection_ids.iter().copied(),
1786 |connection_id| {
1787 session
1788 .peer
1789 .forward_send(session.connection_id, connection_id, request.clone())
1790 },
1791 );
1792 Ok(())
1793}
1794
1795async fn get_private_user_info(
1796 _request: proto::GetPrivateUserInfo,
1797 response: Response<proto::GetPrivateUserInfo>,
1798 session: Session,
1799) -> Result<()> {
1800 let metrics_id = session
1801 .db()
1802 .await
1803 .get_user_metrics_id(session.user_id)
1804 .await?;
1805 let user = session
1806 .db()
1807 .await
1808 .get_user_by_id(session.user_id)
1809 .await?
1810 .ok_or_else(|| anyhow!("user not found"))?;
1811 response.send(proto::GetPrivateUserInfoResponse {
1812 metrics_id,
1813 staff: user.admin,
1814 })?;
1815 Ok(())
1816}
1817
1818fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1819 match message {
1820 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1821 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1822 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1823 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1824 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1825 code: frame.code.into(),
1826 reason: frame.reason,
1827 })),
1828 }
1829}
1830
1831fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1832 match message {
1833 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1834 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1835 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1836 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1837 AxumMessage::Close(frame) => {
1838 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1839 code: frame.code.into(),
1840 reason: frame.reason,
1841 }))
1842 }
1843 }
1844}
1845
1846fn build_initial_contacts_update(
1847 contacts: Vec<db::Contact>,
1848 pool: &ConnectionPool,
1849) -> proto::UpdateContacts {
1850 let mut update = proto::UpdateContacts::default();
1851
1852 for contact in contacts {
1853 match contact {
1854 db::Contact::Accepted {
1855 user_id,
1856 should_notify,
1857 busy,
1858 } => {
1859 update
1860 .contacts
1861 .push(contact_for_user(user_id, should_notify, busy, &pool));
1862 }
1863 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1864 db::Contact::Incoming {
1865 user_id,
1866 should_notify,
1867 } => update
1868 .incoming_requests
1869 .push(proto::IncomingContactRequest {
1870 requester_id: user_id.to_proto(),
1871 should_notify,
1872 }),
1873 }
1874 }
1875
1876 update
1877}
1878
1879fn contact_for_user(
1880 user_id: UserId,
1881 should_notify: bool,
1882 busy: bool,
1883 pool: &ConnectionPool,
1884) -> proto::Contact {
1885 proto::Contact {
1886 user_id: user_id.to_proto(),
1887 online: pool.is_user_online(user_id),
1888 busy,
1889 should_notify,
1890 }
1891}
1892
1893fn room_updated(room: &proto::Room, peer: &Peer) {
1894 for participant in &room.participants {
1895 peer.send(
1896 ConnectionId(participant.peer_id),
1897 proto::RoomUpdated {
1898 room: Some(room.clone()),
1899 },
1900 )
1901 .trace_err();
1902 }
1903}
1904
1905async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1906 let db = session.db().await;
1907 let contacts = db.get_contacts(user_id).await?;
1908 let busy = db.is_user_busy(user_id).await?;
1909
1910 let pool = session.connection_pool().await;
1911 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1912 for contact in contacts {
1913 if let db::Contact::Accepted {
1914 user_id: contact_user_id,
1915 ..
1916 } = contact
1917 {
1918 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1919 session
1920 .peer
1921 .send(
1922 contact_conn_id,
1923 proto::UpdateContacts {
1924 contacts: vec![updated_contact.clone()],
1925 remove_contacts: Default::default(),
1926 incoming_requests: Default::default(),
1927 remove_incoming_requests: Default::default(),
1928 outgoing_requests: Default::default(),
1929 remove_outgoing_requests: Default::default(),
1930 },
1931 )
1932 .trace_err();
1933 }
1934 }
1935 }
1936 Ok(())
1937}
1938
/// Removes the session's connection from its room and performs the follow-up
/// notifications.
///
/// The steps run in this order:
/// 1. Call `leave_room` on the database. While its result is still held,
///    notify the collaborators of every project the connection left, then send
///    the updated room state to the room's participants.
/// 2. Send `CallCanceled` to every connection of each user whose call was
///    reported as canceled.
/// 3. Refresh the contact status of the leaving user and of each of those
///    users.
/// 4. If a LiveKit client is configured, remove this connection from the
///    LiveKit room, and delete the room when the participant list is empty.
///
/// Errors from `leave_room` and `update_user_contacts` are propagated. Send and
/// LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries must be re-sent to their contacts.
    let mut contacts_to_update = HashSet::default();

    // Values copied out of the `leave_room` result, so that result is dropped
    // at the end of the scope below, before the connection pool is locked.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Broadcast before the `mem::take`s below, so the room sent to the
        // participants still carries its `live_kit_room` value.
        room_updated(&left_room.room, &session.peer);
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // The room name is cloned because it is needed again if the room is deleted.
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
1996
1997fn project_left(project: &db::LeftProject, session: &Session) {
1998 for connection_id in &project.connection_ids {
1999 if project.host_user_id == session.user_id {
2000 session
2001 .peer
2002 .send(
2003 *connection_id,
2004 proto::UnshareProject {
2005 project_id: project.id.to_proto(),
2006 },
2007 )
2008 .trace_err();
2009 } else {
2010 session
2011 .peer
2012 .send(
2013 *connection_id,
2014 proto::RemoveProjectCollaborator {
2015 project_id: project.id.to_proto(),
2016 peer_id: session.connection_id.0,
2017 },
2018 )
2019 .trace_err();
2020 }
2021 }
2022
2023 session
2024 .peer
2025 .send(
2026 session.connection_id,
2027 proto::UnshareProject {
2028 project_id: project.id.to_proto(),
2029 },
2030 )
2031 .trace_err();
2032}
2033
/// Extension for consuming a fallible value while still recording failures.
pub trait ResultExt {
    /// Returns the success value as `Some`, or `None` on failure. The
    /// `Result` implementation in this module logs the error before returning
    /// `None`.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The success type yielded by [`ResultExt::trace_err`].
    type Ok;
}
2039
2040impl<T, E> ResultExt for Result<T, E>
2041where
2042 E: std::fmt::Debug,
2043{
2044 type Ok = T;
2045
2046 fn trace_err(self) -> Option<T> {
2047 match self {
2048 Ok(value) => Some(value),
2049 Err(error) => {
2050 tracing::error!("{:?}", error);
2051 None
2052 }
2053 }
2054 }
2055}