1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, DefaultDb, ProjectId, RoomId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::{HashMap, HashSet};
26pub use connection_pool::ConnectionPool;
27use futures::{
28 channel::oneshot,
29 future::{self, BoxFuture},
30 stream::FuturesUnordered,
31 FutureExt, SinkExt, StreamExt, TryStreamExt,
32};
33use lazy_static::lazy_static;
34use prometheus::{register_int_gauge, IntGauge};
35use rpc::{
36 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
37 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
38};
39use serde::{Serialize, Serializer};
40use std::{
41 any::TypeId,
42 fmt,
43 future::Future,
44 marker::PhantomData,
45 net::SocketAddr,
46 ops::{Deref, DerefMut},
47 rc::Rc,
48 sync::{
49 atomic::{AtomicBool, Ordering::SeqCst},
50 Arc,
51 },
52 time::Duration,
53};
54use tokio::{
55 sync::{Mutex, MutexGuard},
56 time::Sleep,
57};
58use tower::ServiceBuilder;
59use tracing::{info_span, instrument, Instrument};
60
lazy_static! {
    // Prometheus gauge holding the number of non-admin connections in the pool; it is
    // refreshed by `handle_metrics` on each scrape.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge holding the value of `project_count_excluding_admins`; it is
    // refreshed by `handle_metrics` on each scrape.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
70
/// Type-erased handler for one incoming message type.
///
/// It receives the boxed envelope plus a clone of the sender's `Session` and returns a
/// `'static` boxed future that performs the work. Errors are handled inside that future,
/// so it resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
73
/// One-shot handle a request handler uses to reply to the request it is processing.
///
/// `responded` is shared with the dispatcher built in `Server::add_request_handler`. The
/// dispatcher reads it after the handler returns to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
79
80impl<R: RequestMessage> Response<R> {
81 fn send(self, payload: R::Response) -> Result<()> {
82 self.responded.store(true, SeqCst);
83 self.peer.respond(self.receipt, payload)?;
84 Ok(())
85 }
86}
87
/// Per-connection context handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquire it through `Session::db`.
    // `handle_connection` creates a fresh mutex for each connection, so this serializes
    // database use only among that connection's handlers.
    db: Arc<Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // The server-wide pool (cloned from `Server::connection_pool`); acquire it through
    // `Session::connection_pool`.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    // When `None`, room handlers reply without LiveKit connection info.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
97
98impl Session {
99 async fn db(&self) -> MutexGuard<DbHandle> {
100 #[cfg(test)]
101 tokio::task::yield_now().await;
102 let guard = self.db.lock().await;
103 #[cfg(test)]
104 tokio::task::yield_now().await;
105 guard
106 }
107
108 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 let guard = self.connection_pool.lock().await;
112 #[cfg(test)]
113 tokio::task::yield_now().await;
114 ConnectionPoolGuard {
115 guard,
116 _not_send: PhantomData,
117 }
118 }
119}
120
121impl fmt::Debug for Session {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 f.debug_struct("Session")
124 .field("user_id", &self.user_id)
125 .field("connection_id", &self.connection_id)
126 .finish()
127 }
128}
129
/// Newtype over the shared database so it can sit behind the per-session mutex; it
/// dereferences to `DefaultDb`.
struct DbHandle(Arc<DefaultDb>);
131
132impl Deref for DbHandle {
133 type Target = DefaultDb;
134
135 fn deref(&self) -> &Self::Target {
136 self.0.as_ref()
137 }
138}
139
/// The collaboration RPC server.
///
/// It owns the `Peer` used to talk to clients, the pool of live connections, the shared
/// application state, and the message-type → handler table.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by the `TypeId` of the message type; populated once in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
}
146
/// Task-spawning and timer abstraction used by `Server::handle_connection`.
///
/// NOTE(review): presumably exists so tests can substitute a deterministic executor for
/// `RealExecutor`; confirm.
pub trait Executor: Send + Clone {
    /// Future returned by `sleep`.
    type Sleep: Send + Future;
    /// Starts `future` in the background; no handle is returned to the caller.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a timer future for `duration`. `handle_connection` uses it to build the
    /// timer that `Peer::add_connection` asks for.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
152
/// `Executor` backed by the ambient Tokio runtime (`tokio::task::spawn` and
/// `tokio::time::sleep`).
#[derive(Clone)]
pub struct RealExecutor;
155
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across
/// an `.await` inside a future that must be `Send`. In test builds, dropping the guard
/// calls `check_invariants` on the pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
160
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
167
168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
169where
170 S: Serializer,
171 T: Deref<Target = U>,
172 U: Serialize,
173{
174 Serialize::serialize(value.deref(), serializer)
175}
176
177impl Server {
178 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
179 let mut server = Self {
180 peer: Peer::new(),
181 app_state,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 };
185
186 server
187 .add_request_handler(ping)
188 .add_request_handler(create_room)
189 .add_request_handler(join_room)
190 .add_message_handler(leave_room)
191 .add_request_handler(call)
192 .add_request_handler(cancel_call)
193 .add_message_handler(decline_call)
194 .add_request_handler(update_participant_location)
195 .add_request_handler(share_project)
196 .add_message_handler(unshare_project)
197 .add_request_handler(join_project)
198 .add_message_handler(leave_project)
199 .add_request_handler(update_project)
200 .add_request_handler(update_worktree)
201 .add_message_handler(start_language_server)
202 .add_message_handler(update_language_server)
203 .add_request_handler(update_diagnostic_summary)
204 .add_request_handler(forward_project_request::<proto::GetHover>)
205 .add_request_handler(forward_project_request::<proto::GetDefinition>)
206 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
207 .add_request_handler(forward_project_request::<proto::GetReferences>)
208 .add_request_handler(forward_project_request::<proto::SearchProject>)
209 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
210 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
211 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
213 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
214 .add_request_handler(forward_project_request::<proto::GetCompletions>)
215 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
216 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
217 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
218 .add_request_handler(forward_project_request::<proto::PrepareRename>)
219 .add_request_handler(forward_project_request::<proto::PerformRename>)
220 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
221 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
222 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
223 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
226 .add_message_handler(create_buffer_for_peer)
227 .add_request_handler(update_buffer)
228 .add_message_handler(update_buffer_file)
229 .add_message_handler(buffer_reloaded)
230 .add_message_handler(buffer_saved)
231 .add_request_handler(save_buffer)
232 .add_request_handler(get_users)
233 .add_request_handler(fuzzy_search_users)
234 .add_request_handler(request_contact)
235 .add_request_handler(remove_contact)
236 .add_request_handler(respond_to_contact_request)
237 .add_request_handler(follow)
238 .add_message_handler(unfollow)
239 .add_message_handler(update_followers)
240 .add_message_handler(update_diff_base)
241 .add_request_handler(get_private_user_info);
242
243 Arc::new(server)
244 }
245
246 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
247 where
248 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
249 Fut: 'static + Send + Future<Output = Result<()>>,
250 M: EnvelopedMessage,
251 {
252 let prev_handler = self.handlers.insert(
253 TypeId::of::<M>(),
254 Box::new(move |envelope, session| {
255 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
256 let span = info_span!(
257 "handle message",
258 payload_type = envelope.payload_type_name()
259 );
260 span.in_scope(|| {
261 tracing::info!(
262 payload_type = envelope.payload_type_name(),
263 "message received"
264 );
265 });
266 let future = (handler)(*envelope, session);
267 async move {
268 if let Err(error) = future.await {
269 tracing::error!(%error, "error handling message");
270 }
271 }
272 .instrument(span)
273 .boxed()
274 }),
275 );
276 if prev_handler.is_some() {
277 panic!("registered a handler for the same message twice");
278 }
279 self
280 }
281
282 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
283 where
284 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
285 Fut: 'static + Send + Future<Output = Result<()>>,
286 M: EnvelopedMessage,
287 {
288 self.add_handler(move |envelope, session| handler(envelope.payload, session));
289 self
290 }
291
292 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
293 where
294 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
295 Fut: Send + Future<Output = Result<()>>,
296 M: RequestMessage,
297 {
298 let handler = Arc::new(handler);
299 self.add_handler(move |envelope, session| {
300 let receipt = envelope.receipt();
301 let handler = handler.clone();
302 async move {
303 let peer = session.peer.clone();
304 let responded = Arc::new(AtomicBool::default());
305 let response = Response {
306 peer: peer.clone(),
307 responded: responded.clone(),
308 receipt,
309 };
310 match (handler)(envelope.payload, response, session).await {
311 Ok(()) => {
312 if responded.load(std::sync::atomic::Ordering::SeqCst) {
313 Ok(())
314 } else {
315 Err(anyhow!("handler did not send a response"))?
316 }
317 }
318 Err(error) => {
319 peer.respond_with_error(
320 receipt,
321 proto::Error {
322 message: error.to_string(),
323 },
324 )?;
325 Err(error)
326 }
327 }
328 }
329 })
330 }
331
332 pub fn handle_connection<E: Executor>(
333 self: &Arc<Self>,
334 connection: Connection,
335 address: String,
336 user: User,
337 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
338 executor: E,
339 ) -> impl Future<Output = Result<()>> {
340 let this = self.clone();
341 let user_id = user.id;
342 let login = user.github_login;
343 let span = info_span!("handle connection", %user_id, %login, %address);
344 async move {
345 let (connection_id, handle_io, mut incoming_rx) = this
346 .peer
347 .add_connection(connection, {
348 let executor = executor.clone();
349 move |duration| {
350 let timer = executor.sleep(duration);
351 async move {
352 timer.await;
353 }
354 }
355 });
356
357 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
358 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
359 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
360
361 if let Some(send_connection_id) = send_connection_id.take() {
362 let _ = send_connection_id.send(connection_id);
363 }
364
365 if !user.connected_once {
366 this.peer.send(connection_id, proto::ShowContacts {})?;
367 this.app_state.db.set_user_connected_once(user_id, true).await?;
368 }
369
370 let (contacts, invite_code) = future::try_join(
371 this.app_state.db.get_contacts(user_id),
372 this.app_state.db.get_invite_code_for_user(user_id)
373 ).await?;
374
375 {
376 let mut pool = this.connection_pool.lock().await;
377 pool.add_connection(connection_id, user_id, user.admin);
378 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
379
380 if let Some((code, count)) = invite_code {
381 this.peer.send(connection_id, proto::UpdateInviteInfo {
382 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
383 count,
384 })?;
385 }
386 }
387
388 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
389 this.peer.send(connection_id, incoming_call)?;
390 }
391
392 let session = Session {
393 user_id,
394 connection_id,
395 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
396 peer: this.peer.clone(),
397 connection_pool: this.connection_pool.clone(),
398 live_kit_client: this.app_state.live_kit_client.clone()
399 };
400 update_user_contacts(user_id, &session).await?;
401
402 let handle_io = handle_io.fuse();
403 futures::pin_mut!(handle_io);
404
405 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
406 // This prevents deadlocks when e.g., client A performs a request to client B and
407 // client B performs a request to client A. If both clients stop processing further
408 // messages until their respective request completes, they won't have a chance to
409 // respond to the other client's request and cause a deadlock.
410 //
411 // This arrangement ensures we will attempt to process earlier messages first, but fall
412 // back to processing messages arrived later in the spirit of making progress.
413 let mut foreground_message_handlers = FuturesUnordered::new();
414 loop {
415 let next_message = incoming_rx.next().fuse();
416 futures::pin_mut!(next_message);
417 futures::select_biased! {
418 result = handle_io => {
419 if let Err(error) = result {
420 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
421 }
422 break;
423 }
424 _ = foreground_message_handlers.next() => {}
425 message = next_message => {
426 if let Some(message) = message {
427 let type_name = message.payload_type_name();
428 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
429 let span_enter = span.enter();
430 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
431 let is_background = message.is_background();
432 let handle_message = (handler)(message, session.clone());
433 drop(span_enter);
434
435 let handle_message = handle_message.instrument(span);
436 if is_background {
437 executor.spawn_detached(handle_message);
438 } else {
439 foreground_message_handlers.push(handle_message);
440 }
441 } else {
442 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
443 }
444 } else {
445 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
446 break;
447 }
448 }
449 }
450 }
451
452 drop(foreground_message_handlers);
453 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
454 if let Err(error) = sign_out(session).await {
455 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
456 }
457
458 Ok(())
459 }.instrument(span)
460 }
461
462 pub async fn invite_code_redeemed(
463 self: &Arc<Self>,
464 inviter_id: UserId,
465 invitee_id: UserId,
466 ) -> Result<()> {
467 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
468 if let Some(code) = &user.invite_code {
469 let pool = self.connection_pool.lock().await;
470 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
471 for connection_id in pool.user_connection_ids(inviter_id) {
472 self.peer.send(
473 connection_id,
474 proto::UpdateContacts {
475 contacts: vec![invitee_contact.clone()],
476 ..Default::default()
477 },
478 )?;
479 self.peer.send(
480 connection_id,
481 proto::UpdateInviteInfo {
482 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
483 count: user.invite_count as u32,
484 },
485 )?;
486 }
487 }
488 }
489 Ok(())
490 }
491
492 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
493 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
494 if let Some(invite_code) = &user.invite_code {
495 let pool = self.connection_pool.lock().await;
496 for connection_id in pool.user_connection_ids(user_id) {
497 self.peer.send(
498 connection_id,
499 proto::UpdateInviteInfo {
500 url: format!(
501 "{}{}",
502 self.app_state.config.invite_link_prefix, invite_code
503 ),
504 count: user.invite_count as u32,
505 },
506 )?;
507 }
508 }
509 }
510 Ok(())
511 }
512
513 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
514 ServerSnapshot {
515 connection_pool: ConnectionPoolGuard {
516 guard: self.connection_pool.lock().await,
517 _not_send: PhantomData,
518 },
519 peer: &self.peer,
520 }
521 }
522}
523
524impl<'a> Deref for ConnectionPoolGuard<'a> {
525 type Target = ConnectionPool;
526
527 fn deref(&self) -> &Self::Target {
528 &*self.guard
529 }
530}
531
532impl<'a> DerefMut for ConnectionPoolGuard<'a> {
533 fn deref_mut(&mut self) -> &mut Self::Target {
534 &mut *self.guard
535 }
536}
537
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` on the pool each time the guard is
    /// released. In other builds it does nothing beyond releasing the lock.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
544
545impl Executor for RealExecutor {
546 type Sleep = Sleep;
547
548 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
549 tokio::task::spawn(future);
550 }
551
552 fn sleep(&self, duration: Duration) -> Self::Sleep {
553 tokio::time::sleep(duration)
554 }
555}
556
557fn broadcast<F>(
558 sender_id: ConnectionId,
559 receiver_ids: impl IntoIterator<Item = ConnectionId>,
560 mut f: F,
561) where
562 F: FnMut(ConnectionId) -> anyhow::Result<()>,
563{
564 for receiver_id in receiver_ids {
565 if receiver_id != sender_id {
566 f(receiver_id).trace_err();
567 }
568 }
569}
570
lazy_static! {
    // Name of the HTTP header in which clients report their RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
574
/// Typed `x-zed-protocol-version` header. `handle_websocket_request` compares its value
/// against `rpc::PROTOCOL_VERSION`.
pub struct ProtocolVersion(u32);
576
577impl Header for ProtocolVersion {
578 fn name() -> &'static HeaderName {
579 &ZED_PROTOCOL_VERSION
580 }
581
582 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
583 where
584 Self: Sized,
585 I: Iterator<Item = &'i axum::http::HeaderValue>,
586 {
587 let version = values
588 .next()
589 .ok_or_else(axum::headers::Error::invalid)?
590 .to_str()
591 .map_err(|_| axum::headers::Error::invalid())?
592 .parse()
593 .map_err(|_| axum::headers::Error::invalid())?;
594 Ok(Self(version))
595 }
596
597 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
598 values.extend([self.0.to_string().parse().unwrap()]);
599 }
600}
601
/// Builds the collab HTTP router.
///
/// Axum's `Router::layer` only wraps routes added before it. So only `/rpc` (the
/// WebSocket endpoint) goes through `auth::validate_header` and receives the `AppState`
/// extension. `/metrics` is added afterwards and is not covered by that layer. The final
/// `Extension(server)` layer applies to both routes.
/// NOTE(review): `/metrics` is therefore served without header validation; confirm this
/// is intended.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
613
614pub async fn handle_websocket_request(
615 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
616 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
617 Extension(server): Extension<Arc<Server>>,
618 Extension(user): Extension<User>,
619 ws: WebSocketUpgrade,
620) -> axum::response::Response {
621 if protocol_version != rpc::PROTOCOL_VERSION {
622 return (
623 StatusCode::UPGRADE_REQUIRED,
624 "client must be upgraded".to_string(),
625 )
626 .into_response();
627 }
628 let socket_address = socket_address.to_string();
629 ws.on_upgrade(move |socket| {
630 use util::ResultExt;
631 let socket = socket
632 .map_ok(to_tungstenite_message)
633 .err_into()
634 .with(|message| async move { Ok(to_axum_message(message)) });
635 let connection = Connection::new(Box::pin(socket));
636 async move {
637 server
638 .handle_connection(connection, socket_address, user, None, RealExecutor)
639 .await
640 .log_err();
641 }
642 })
643}
644
645pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
646 let connections = server
647 .connection_pool
648 .lock()
649 .await
650 .connections()
651 .filter(|connection| !connection.admin)
652 .count();
653
654 METRIC_CONNECTIONS.set(connections as _);
655
656 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
657 METRIC_SHARED_PROJECTS.set(shared_projects as _);
658
659 let encoder = prometheus::TextEncoder::new();
660 let metric_families = prometheus::gather();
661 let encoded_metrics = encoder
662 .encode_to_string(&metric_families)
663 .map_err(|err| anyhow!("{}", err))?;
664 Ok(encoded_metrics)
665}
666
667#[instrument(err)]
668async fn sign_out(session: Session) -> Result<()> {
669 session.peer.disconnect(session.connection_id);
670 let decline_calls = {
671 let mut pool = session.connection_pool().await;
672 pool.remove_connection(session.connection_id)?;
673 let mut connections = pool.user_connection_ids(session.user_id);
674 connections.next().is_none()
675 };
676
677 leave_room_for_session(&session).await.trace_err();
678 if decline_calls {
679 if let Some(room) = session
680 .db()
681 .await
682 .decline_call(None, session.user_id)
683 .await
684 .trace_err()
685 {
686 room_updated(&room, &session);
687 }
688 }
689
690 update_user_contacts(session.user_id, &session).await?;
691
692 Ok(())
693}
694
695async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
696 response.send(proto::Ack {})?;
697 Ok(())
698}
699
700async fn create_room(
701 _request: proto::CreateRoom,
702 response: Response<proto::CreateRoom>,
703 session: Session,
704) -> Result<()> {
705 let room = session
706 .db()
707 .await
708 .create_room(session.user_id, session.connection_id)
709 .await?;
710
711 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
712 if let Some(_) = live_kit
713 .create_room(room.live_kit_room.clone())
714 .await
715 .trace_err()
716 {
717 if let Some(token) = live_kit
718 .room_token(&room.live_kit_room, &session.connection_id.to_string())
719 .trace_err()
720 {
721 Some(proto::LiveKitConnectionInfo {
722 server_url: live_kit.url().into(),
723 token,
724 })
725 } else {
726 None
727 }
728 } else {
729 None
730 }
731 } else {
732 None
733 };
734
735 response.send(proto::CreateRoomResponse {
736 room: Some(room),
737 live_kit_connection_info,
738 })?;
739 update_user_contacts(session.user_id, &session).await?;
740 Ok(())
741}
742
743async fn join_room(
744 request: proto::JoinRoom,
745 response: Response<proto::JoinRoom>,
746 session: Session,
747) -> Result<()> {
748 let room = session
749 .db()
750 .await
751 .join_room(
752 RoomId::from_proto(request.id),
753 session.user_id,
754 session.connection_id,
755 )
756 .await?;
757 for connection_id in session
758 .connection_pool()
759 .await
760 .user_connection_ids(session.user_id)
761 {
762 session
763 .peer
764 .send(connection_id, proto::CallCanceled {})
765 .trace_err();
766 }
767
768 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
769 if let Some(token) = live_kit
770 .room_token(&room.live_kit_room, &session.connection_id.to_string())
771 .trace_err()
772 {
773 Some(proto::LiveKitConnectionInfo {
774 server_url: live_kit.url().into(),
775 token,
776 })
777 } else {
778 None
779 }
780 } else {
781 None
782 };
783
784 room_updated(&room, &session);
785 response.send(proto::JoinRoomResponse {
786 room: Some(room),
787 live_kit_connection_info,
788 })?;
789
790 update_user_contacts(session.user_id, &session).await?;
791 Ok(())
792}
793
794async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
795 leave_room_for_session(&session).await
796}
797
798async fn call(
799 request: proto::Call,
800 response: Response<proto::Call>,
801 session: Session,
802) -> Result<()> {
803 let room_id = RoomId::from_proto(request.room_id);
804 let calling_user_id = session.user_id;
805 let calling_connection_id = session.connection_id;
806 let called_user_id = UserId::from_proto(request.called_user_id);
807 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
808 if !session
809 .db()
810 .await
811 .has_contact(calling_user_id, called_user_id)
812 .await?
813 {
814 return Err(anyhow!("cannot call a user who isn't a contact"))?;
815 }
816
817 let (room, incoming_call) = session
818 .db()
819 .await
820 .call(
821 room_id,
822 calling_user_id,
823 calling_connection_id,
824 called_user_id,
825 initial_project_id,
826 )
827 .await?;
828 room_updated(&room, &session);
829 update_user_contacts(called_user_id, &session).await?;
830
831 let mut calls = session
832 .connection_pool()
833 .await
834 .user_connection_ids(called_user_id)
835 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
836 .collect::<FuturesUnordered<_>>();
837
838 while let Some(call_response) = calls.next().await {
839 match call_response.as_ref() {
840 Ok(_) => {
841 response.send(proto::Ack {})?;
842 return Ok(());
843 }
844 Err(_) => {
845 call_response.trace_err();
846 }
847 }
848 }
849
850 let room = session
851 .db()
852 .await
853 .call_failed(room_id, called_user_id)
854 .await?;
855 room_updated(&room, &session);
856 update_user_contacts(called_user_id, &session).await?;
857
858 Err(anyhow!("failed to ring user"))?
859}
860
861async fn cancel_call(
862 request: proto::CancelCall,
863 response: Response<proto::CancelCall>,
864 session: Session,
865) -> Result<()> {
866 let called_user_id = UserId::from_proto(request.called_user_id);
867 let room_id = RoomId::from_proto(request.room_id);
868 let room = session
869 .db()
870 .await
871 .cancel_call(Some(room_id), session.connection_id, called_user_id)
872 .await?;
873 for connection_id in session
874 .connection_pool()
875 .await
876 .user_connection_ids(called_user_id)
877 {
878 session
879 .peer
880 .send(connection_id, proto::CallCanceled {})
881 .trace_err();
882 }
883 room_updated(&room, &session);
884 response.send(proto::Ack {})?;
885
886 update_user_contacts(called_user_id, &session).await?;
887 Ok(())
888}
889
890async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
891 let room_id = RoomId::from_proto(message.room_id);
892 let room = session
893 .db()
894 .await
895 .decline_call(Some(room_id), session.user_id)
896 .await?;
897 for connection_id in session
898 .connection_pool()
899 .await
900 .user_connection_ids(session.user_id)
901 {
902 session
903 .peer
904 .send(connection_id, proto::CallCanceled {})
905 .trace_err();
906 }
907 room_updated(&room, &session);
908 update_user_contacts(session.user_id, &session).await?;
909 Ok(())
910}
911
912async fn update_participant_location(
913 request: proto::UpdateParticipantLocation,
914 response: Response<proto::UpdateParticipantLocation>,
915 session: Session,
916) -> Result<()> {
917 let room_id = RoomId::from_proto(request.room_id);
918 let location = request
919 .location
920 .ok_or_else(|| anyhow!("invalid location"))?;
921 let room = session
922 .db()
923 .await
924 .update_room_participant_location(room_id, session.connection_id, location)
925 .await?;
926 room_updated(&room, &session);
927 response.send(proto::Ack {})?;
928 Ok(())
929}
930
931async fn share_project(
932 request: proto::ShareProject,
933 response: Response<proto::ShareProject>,
934 session: Session,
935) -> Result<()> {
936 let (project_id, room) = session
937 .db()
938 .await
939 .share_project(
940 RoomId::from_proto(request.room_id),
941 session.connection_id,
942 &request.worktrees,
943 )
944 .await?;
945 response.send(proto::ShareProjectResponse {
946 project_id: project_id.to_proto(),
947 })?;
948 room_updated(&room, &session);
949
950 Ok(())
951}
952
953async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
954 let project_id = ProjectId::from_proto(message.project_id);
955
956 let (room, guest_connection_ids) = session
957 .db()
958 .await
959 .unshare_project(project_id, session.connection_id)
960 .await?;
961
962 broadcast(session.connection_id, guest_connection_ids, |conn_id| {
963 session.peer.send(conn_id, message.clone())
964 });
965 room_updated(&room, &session);
966
967 Ok(())
968}
969
970async fn join_project(
971 request: proto::JoinProject,
972 response: Response<proto::JoinProject>,
973 session: Session,
974) -> Result<()> {
975 let project_id = ProjectId::from_proto(request.project_id);
976 let guest_user_id = session.user_id;
977
978 tracing::info!(%project_id, "join project");
979
980 let (project, replica_id) = session
981 .db()
982 .await
983 .join_project(project_id, session.connection_id)
984 .await?;
985
986 let collaborators = project
987 .collaborators
988 .iter()
989 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
990 .map(|collaborator| proto::Collaborator {
991 peer_id: collaborator.connection_id as u32,
992 replica_id: collaborator.replica_id.0 as u32,
993 user_id: collaborator.user_id.to_proto(),
994 })
995 .collect::<Vec<_>>();
996 let worktrees = project
997 .worktrees
998 .iter()
999 .map(|(id, worktree)| proto::WorktreeMetadata {
1000 id: id.to_proto(),
1001 root_name: worktree.root_name.clone(),
1002 visible: worktree.visible,
1003 abs_path: worktree.abs_path.clone(),
1004 })
1005 .collect::<Vec<_>>();
1006
1007 for collaborator in &collaborators {
1008 session
1009 .peer
1010 .send(
1011 ConnectionId(collaborator.peer_id),
1012 proto::AddProjectCollaborator {
1013 project_id: project_id.to_proto(),
1014 collaborator: Some(proto::Collaborator {
1015 peer_id: session.connection_id.0,
1016 replica_id: replica_id.0 as u32,
1017 user_id: guest_user_id.to_proto(),
1018 }),
1019 },
1020 )
1021 .trace_err();
1022 }
1023
1024 // First, we send the metadata associated with each worktree.
1025 response.send(proto::JoinProjectResponse {
1026 worktrees: worktrees.clone(),
1027 replica_id: replica_id.0 as u32,
1028 collaborators: collaborators.clone(),
1029 language_servers: project.language_servers.clone(),
1030 })?;
1031
1032 for (worktree_id, worktree) in project.worktrees {
1033 #[cfg(any(test, feature = "test-support"))]
1034 const MAX_CHUNK_SIZE: usize = 2;
1035 #[cfg(not(any(test, feature = "test-support")))]
1036 const MAX_CHUNK_SIZE: usize = 256;
1037
1038 // Stream this worktree's entries.
1039 let message = proto::UpdateWorktree {
1040 project_id: project_id.to_proto(),
1041 worktree_id: worktree_id.to_proto(),
1042 abs_path: worktree.abs_path.clone(),
1043 root_name: worktree.root_name,
1044 updated_entries: worktree.entries,
1045 removed_entries: Default::default(),
1046 scan_id: worktree.scan_id,
1047 is_last_update: worktree.is_complete,
1048 };
1049 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1050 session.peer.send(session.connection_id, update.clone())?;
1051 }
1052
1053 // Stream this worktree's diagnostics.
1054 for summary in worktree.diagnostic_summaries {
1055 session.peer.send(
1056 session.connection_id,
1057 proto::UpdateDiagnosticSummary {
1058 project_id: project_id.to_proto(),
1059 worktree_id: worktree.id.to_proto(),
1060 summary: Some(summary),
1061 },
1062 )?;
1063 }
1064 }
1065
1066 for language_server in &project.language_servers {
1067 session.peer.send(
1068 session.connection_id,
1069 proto::UpdateLanguageServer {
1070 project_id: project_id.to_proto(),
1071 language_server_id: language_server.id,
1072 variant: Some(
1073 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1074 proto::LspDiskBasedDiagnosticsUpdated {},
1075 ),
1076 ),
1077 },
1078 )?;
1079 }
1080
1081 Ok(())
1082}
1083
1084async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1085 let sender_id = session.connection_id;
1086 let project_id = ProjectId::from_proto(request.project_id);
1087 let project;
1088 {
1089 project = session
1090 .db()
1091 .await
1092 .leave_project(project_id, sender_id)
1093 .await?;
1094 tracing::info!(
1095 %project_id,
1096 host_user_id = %project.host_user_id,
1097 host_connection_id = %project.host_connection_id,
1098 "leave project"
1099 );
1100
1101 broadcast(sender_id, project.connection_ids, |conn_id| {
1102 session.peer.send(
1103 conn_id,
1104 proto::RemoveProjectCollaborator {
1105 project_id: project_id.to_proto(),
1106 peer_id: sender_id.0,
1107 },
1108 )
1109 });
1110 }
1111
1112 Ok(())
1113}
1114
1115async fn update_project(
1116 request: proto::UpdateProject,
1117 response: Response<proto::UpdateProject>,
1118 session: Session,
1119) -> Result<()> {
1120 let project_id = ProjectId::from_proto(request.project_id);
1121 let (room, guest_connection_ids) = session
1122 .db()
1123 .await
1124 .update_project(project_id, session.connection_id, &request.worktrees)
1125 .await?;
1126 broadcast(
1127 session.connection_id,
1128 guest_connection_ids,
1129 |connection_id| {
1130 session
1131 .peer
1132 .forward_send(session.connection_id, connection_id, request.clone())
1133 },
1134 );
1135 room_updated(&room, &session);
1136 response.send(proto::Ack {})?;
1137
1138 Ok(())
1139}
1140
1141async fn update_worktree(
1142 request: proto::UpdateWorktree,
1143 response: Response<proto::UpdateWorktree>,
1144 session: Session,
1145) -> Result<()> {
1146 let guest_connection_ids = session
1147 .db()
1148 .await
1149 .update_worktree(&request, session.connection_id)
1150 .await?;
1151
1152 broadcast(
1153 session.connection_id,
1154 guest_connection_ids,
1155 |connection_id| {
1156 session
1157 .peer
1158 .forward_send(session.connection_id, connection_id, request.clone())
1159 },
1160 );
1161 response.send(proto::Ack {})?;
1162 Ok(())
1163}
1164
1165async fn update_diagnostic_summary(
1166 request: proto::UpdateDiagnosticSummary,
1167 response: Response<proto::UpdateDiagnosticSummary>,
1168 session: Session,
1169) -> Result<()> {
1170 let guest_connection_ids = session
1171 .db()
1172 .await
1173 .update_diagnostic_summary(&request, session.connection_id)
1174 .await?;
1175
1176 broadcast(
1177 session.connection_id,
1178 guest_connection_ids,
1179 |connection_id| {
1180 session
1181 .peer
1182 .forward_send(session.connection_id, connection_id, request.clone())
1183 },
1184 );
1185
1186 response.send(proto::Ack {})?;
1187 Ok(())
1188}
1189
1190async fn start_language_server(
1191 request: proto::StartLanguageServer,
1192 session: Session,
1193) -> Result<()> {
1194 let guest_connection_ids = session
1195 .db()
1196 .await
1197 .start_language_server(&request, session.connection_id)
1198 .await?;
1199
1200 broadcast(
1201 session.connection_id,
1202 guest_connection_ids,
1203 |connection_id| {
1204 session
1205 .peer
1206 .forward_send(session.connection_id, connection_id, request.clone())
1207 },
1208 );
1209 Ok(())
1210}
1211
1212async fn update_language_server(
1213 request: proto::UpdateLanguageServer,
1214 session: Session,
1215) -> Result<()> {
1216 let project_id = ProjectId::from_proto(request.project_id);
1217 let project_connection_ids = session
1218 .db()
1219 .await
1220 .project_connection_ids(project_id, session.connection_id)
1221 .await?;
1222 broadcast(
1223 session.connection_id,
1224 project_connection_ids,
1225 |connection_id| {
1226 session
1227 .peer
1228 .forward_send(session.connection_id, connection_id, request.clone())
1229 },
1230 );
1231 Ok(())
1232}
1233
1234async fn forward_project_request<T>(
1235 request: T,
1236 response: Response<T>,
1237 session: Session,
1238) -> Result<()>
1239where
1240 T: EntityMessage + RequestMessage,
1241{
1242 let project_id = ProjectId::from_proto(request.remote_entity_id());
1243 let collaborators = session
1244 .db()
1245 .await
1246 .project_collaborators(project_id, session.connection_id)
1247 .await?;
1248 let host = collaborators
1249 .iter()
1250 .find(|collaborator| collaborator.is_host)
1251 .ok_or_else(|| anyhow!("host not found"))?;
1252
1253 let payload = session
1254 .peer
1255 .forward_request(
1256 session.connection_id,
1257 ConnectionId(host.connection_id as u32),
1258 request,
1259 )
1260 .await?;
1261
1262 response.send(payload)?;
1263 Ok(())
1264}
1265
1266async fn save_buffer(
1267 request: proto::SaveBuffer,
1268 response: Response<proto::SaveBuffer>,
1269 session: Session,
1270) -> Result<()> {
1271 let project_id = ProjectId::from_proto(request.project_id);
1272 let collaborators = session
1273 .db()
1274 .await
1275 .project_collaborators(project_id, session.connection_id)
1276 .await?;
1277 let host = collaborators
1278 .into_iter()
1279 .find(|collaborator| collaborator.is_host)
1280 .ok_or_else(|| anyhow!("host not found"))?;
1281 let host_connection_id = ConnectionId(host.connection_id as u32);
1282 let response_payload = session
1283 .peer
1284 .forward_request(session.connection_id, host_connection_id, request.clone())
1285 .await?;
1286
1287 let mut collaborators = session
1288 .db()
1289 .await
1290 .project_collaborators(project_id, session.connection_id)
1291 .await?;
1292 collaborators
1293 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1294 let project_connection_ids = collaborators
1295 .into_iter()
1296 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1297 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1298 session
1299 .peer
1300 .forward_send(host_connection_id, conn_id, response_payload.clone())
1301 });
1302 response.send(response_payload)?;
1303 Ok(())
1304}
1305
1306async fn create_buffer_for_peer(
1307 request: proto::CreateBufferForPeer,
1308 session: Session,
1309) -> Result<()> {
1310 session.peer.forward_send(
1311 session.connection_id,
1312 ConnectionId(request.peer_id),
1313 request,
1314 )?;
1315 Ok(())
1316}
1317
1318async fn update_buffer(
1319 request: proto::UpdateBuffer,
1320 response: Response<proto::UpdateBuffer>,
1321 session: Session,
1322) -> Result<()> {
1323 let project_id = ProjectId::from_proto(request.project_id);
1324 let project_connection_ids = session
1325 .db()
1326 .await
1327 .project_connection_ids(project_id, session.connection_id)
1328 .await?;
1329
1330 broadcast(
1331 session.connection_id,
1332 project_connection_ids,
1333 |connection_id| {
1334 session
1335 .peer
1336 .forward_send(session.connection_id, connection_id, request.clone())
1337 },
1338 );
1339 response.send(proto::Ack {})?;
1340 Ok(())
1341}
1342
1343async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1344 let project_id = ProjectId::from_proto(request.project_id);
1345 let project_connection_ids = session
1346 .db()
1347 .await
1348 .project_connection_ids(project_id, session.connection_id)
1349 .await?;
1350
1351 broadcast(
1352 session.connection_id,
1353 project_connection_ids,
1354 |connection_id| {
1355 session
1356 .peer
1357 .forward_send(session.connection_id, connection_id, request.clone())
1358 },
1359 );
1360 Ok(())
1361}
1362
1363async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1364 let project_id = ProjectId::from_proto(request.project_id);
1365 let project_connection_ids = session
1366 .db()
1367 .await
1368 .project_connection_ids(project_id, session.connection_id)
1369 .await?;
1370 broadcast(
1371 session.connection_id,
1372 project_connection_ids,
1373 |connection_id| {
1374 session
1375 .peer
1376 .forward_send(session.connection_id, connection_id, request.clone())
1377 },
1378 );
1379 Ok(())
1380}
1381
1382async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1383 let project_id = ProjectId::from_proto(request.project_id);
1384 let project_connection_ids = session
1385 .db()
1386 .await
1387 .project_connection_ids(project_id, session.connection_id)
1388 .await?;
1389 broadcast(
1390 session.connection_id,
1391 project_connection_ids,
1392 |connection_id| {
1393 session
1394 .peer
1395 .forward_send(session.connection_id, connection_id, request.clone())
1396 },
1397 );
1398 Ok(())
1399}
1400
1401async fn follow(
1402 request: proto::Follow,
1403 response: Response<proto::Follow>,
1404 session: Session,
1405) -> Result<()> {
1406 let project_id = ProjectId::from_proto(request.project_id);
1407 let leader_id = ConnectionId(request.leader_id);
1408 let follower_id = session.connection_id;
1409 let project_connection_ids = session
1410 .db()
1411 .await
1412 .project_connection_ids(project_id, session.connection_id)
1413 .await?;
1414
1415 if !project_connection_ids.contains(&leader_id) {
1416 Err(anyhow!("no such peer"))?;
1417 }
1418
1419 let mut response_payload = session
1420 .peer
1421 .forward_request(session.connection_id, leader_id, request)
1422 .await?;
1423 response_payload
1424 .views
1425 .retain(|view| view.leader_id != Some(follower_id.0));
1426 response.send(response_payload)?;
1427 Ok(())
1428}
1429
1430async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1431 let project_id = ProjectId::from_proto(request.project_id);
1432 let leader_id = ConnectionId(request.leader_id);
1433 let project_connection_ids = session
1434 .db()
1435 .await
1436 .project_connection_ids(project_id, session.connection_id)
1437 .await?;
1438 if !project_connection_ids.contains(&leader_id) {
1439 Err(anyhow!("no such peer"))?;
1440 }
1441 session
1442 .peer
1443 .forward_send(session.connection_id, leader_id, request)?;
1444 Ok(())
1445}
1446
1447async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1448 let project_id = ProjectId::from_proto(request.project_id);
1449 let project_connection_ids = session
1450 .db
1451 .lock()
1452 .await
1453 .project_connection_ids(project_id, session.connection_id)
1454 .await?;
1455
1456 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1457 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1458 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1459 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1460 });
1461 for follower_id in &request.follower_ids {
1462 let follower_id = ConnectionId(*follower_id);
1463 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1464 session
1465 .peer
1466 .forward_send(session.connection_id, follower_id, request.clone())?;
1467 }
1468 }
1469 Ok(())
1470}
1471
1472async fn get_users(
1473 request: proto::GetUsers,
1474 response: Response<proto::GetUsers>,
1475 session: Session,
1476) -> Result<()> {
1477 let user_ids = request
1478 .user_ids
1479 .into_iter()
1480 .map(UserId::from_proto)
1481 .collect();
1482 let users = session
1483 .db()
1484 .await
1485 .get_users_by_ids(user_ids)
1486 .await?
1487 .into_iter()
1488 .map(|user| proto::User {
1489 id: user.id.to_proto(),
1490 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1491 github_login: user.github_login,
1492 })
1493 .collect();
1494 response.send(proto::UsersResponse { users })?;
1495 Ok(())
1496}
1497
1498async fn fuzzy_search_users(
1499 request: proto::FuzzySearchUsers,
1500 response: Response<proto::FuzzySearchUsers>,
1501 session: Session,
1502) -> Result<()> {
1503 let query = request.query;
1504 let users = match query.len() {
1505 0 => vec![],
1506 1 | 2 => session
1507 .db()
1508 .await
1509 .get_user_by_github_account(&query, None)
1510 .await?
1511 .into_iter()
1512 .collect(),
1513 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1514 };
1515 let users = users
1516 .into_iter()
1517 .filter(|user| user.id != session.user_id)
1518 .map(|user| proto::User {
1519 id: user.id.to_proto(),
1520 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1521 github_login: user.github_login,
1522 })
1523 .collect();
1524 response.send(proto::UsersResponse { users })?;
1525 Ok(())
1526}
1527
1528async fn request_contact(
1529 request: proto::RequestContact,
1530 response: Response<proto::RequestContact>,
1531 session: Session,
1532) -> Result<()> {
1533 let requester_id = session.user_id;
1534 let responder_id = UserId::from_proto(request.responder_id);
1535 if requester_id == responder_id {
1536 return Err(anyhow!("cannot add yourself as a contact"))?;
1537 }
1538
1539 session
1540 .db()
1541 .await
1542 .send_contact_request(requester_id, responder_id)
1543 .await?;
1544
1545 // Update outgoing contact requests of requester
1546 let mut update = proto::UpdateContacts::default();
1547 update.outgoing_requests.push(responder_id.to_proto());
1548 for connection_id in session
1549 .connection_pool()
1550 .await
1551 .user_connection_ids(requester_id)
1552 {
1553 session.peer.send(connection_id, update.clone())?;
1554 }
1555
1556 // Update incoming contact requests of responder
1557 let mut update = proto::UpdateContacts::default();
1558 update
1559 .incoming_requests
1560 .push(proto::IncomingContactRequest {
1561 requester_id: requester_id.to_proto(),
1562 should_notify: true,
1563 });
1564 for connection_id in session
1565 .connection_pool()
1566 .await
1567 .user_connection_ids(responder_id)
1568 {
1569 session.peer.send(connection_id, update.clone())?;
1570 }
1571
1572 response.send(proto::Ack {})?;
1573 Ok(())
1574}
1575
1576async fn respond_to_contact_request(
1577 request: proto::RespondToContactRequest,
1578 response: Response<proto::RespondToContactRequest>,
1579 session: Session,
1580) -> Result<()> {
1581 let responder_id = session.user_id;
1582 let requester_id = UserId::from_proto(request.requester_id);
1583 let db = session.db().await;
1584 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1585 db.dismiss_contact_notification(responder_id, requester_id)
1586 .await?;
1587 } else {
1588 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1589
1590 db.respond_to_contact_request(responder_id, requester_id, accept)
1591 .await?;
1592 let busy = db.is_user_busy(requester_id).await?;
1593
1594 let pool = session.connection_pool().await;
1595 // Update responder with new contact
1596 let mut update = proto::UpdateContacts::default();
1597 if accept {
1598 update
1599 .contacts
1600 .push(contact_for_user(requester_id, false, busy, &pool));
1601 }
1602 update
1603 .remove_incoming_requests
1604 .push(requester_id.to_proto());
1605 for connection_id in pool.user_connection_ids(responder_id) {
1606 session.peer.send(connection_id, update.clone())?;
1607 }
1608
1609 // Update requester with new contact
1610 let mut update = proto::UpdateContacts::default();
1611 if accept {
1612 update
1613 .contacts
1614 .push(contact_for_user(responder_id, true, busy, &pool));
1615 }
1616 update
1617 .remove_outgoing_requests
1618 .push(responder_id.to_proto());
1619 for connection_id in pool.user_connection_ids(requester_id) {
1620 session.peer.send(connection_id, update.clone())?;
1621 }
1622 }
1623
1624 response.send(proto::Ack {})?;
1625 Ok(())
1626}
1627
1628async fn remove_contact(
1629 request: proto::RemoveContact,
1630 response: Response<proto::RemoveContact>,
1631 session: Session,
1632) -> Result<()> {
1633 let requester_id = session.user_id;
1634 let responder_id = UserId::from_proto(request.user_id);
1635 let db = session.db().await;
1636 db.remove_contact(requester_id, responder_id).await?;
1637
1638 let pool = session.connection_pool().await;
1639 // Update outgoing contact requests of requester
1640 let mut update = proto::UpdateContacts::default();
1641 update
1642 .remove_outgoing_requests
1643 .push(responder_id.to_proto());
1644 for connection_id in pool.user_connection_ids(requester_id) {
1645 session.peer.send(connection_id, update.clone())?;
1646 }
1647
1648 // Update incoming contact requests of responder
1649 let mut update = proto::UpdateContacts::default();
1650 update
1651 .remove_incoming_requests
1652 .push(requester_id.to_proto());
1653 for connection_id in pool.user_connection_ids(responder_id) {
1654 session.peer.send(connection_id, update.clone())?;
1655 }
1656
1657 response.send(proto::Ack {})?;
1658 Ok(())
1659}
1660
1661async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1662 let project_id = ProjectId::from_proto(request.project_id);
1663 let project_connection_ids = session
1664 .db()
1665 .await
1666 .project_connection_ids(project_id, session.connection_id)
1667 .await?;
1668 broadcast(
1669 session.connection_id,
1670 project_connection_ids,
1671 |connection_id| {
1672 session
1673 .peer
1674 .forward_send(session.connection_id, connection_id, request.clone())
1675 },
1676 );
1677 Ok(())
1678}
1679
1680async fn get_private_user_info(
1681 _request: proto::GetPrivateUserInfo,
1682 response: Response<proto::GetPrivateUserInfo>,
1683 session: Session,
1684) -> Result<()> {
1685 let metrics_id = session
1686 .db()
1687 .await
1688 .get_user_metrics_id(session.user_id)
1689 .await?;
1690 let user = session
1691 .db()
1692 .await
1693 .get_user_by_id(session.user_id)
1694 .await?
1695 .ok_or_else(|| anyhow!("user not found"))?;
1696 response.send(proto::GetPrivateUserInfoResponse {
1697 metrics_id,
1698 staff: user.admin,
1699 })?;
1700 Ok(())
1701}
1702
1703fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1704 match message {
1705 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1706 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1707 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1708 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1709 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1710 code: frame.code.into(),
1711 reason: frame.reason,
1712 })),
1713 }
1714}
1715
1716fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1717 match message {
1718 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1719 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1720 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1721 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1722 AxumMessage::Close(frame) => {
1723 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1724 code: frame.code.into(),
1725 reason: frame.reason,
1726 }))
1727 }
1728 }
1729}
1730
1731fn build_initial_contacts_update(
1732 contacts: Vec<db::Contact>,
1733 pool: &ConnectionPool,
1734) -> proto::UpdateContacts {
1735 let mut update = proto::UpdateContacts::default();
1736
1737 for contact in contacts {
1738 match contact {
1739 db::Contact::Accepted {
1740 user_id,
1741 should_notify,
1742 busy,
1743 } => {
1744 update
1745 .contacts
1746 .push(contact_for_user(user_id, should_notify, busy, &pool));
1747 }
1748 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1749 db::Contact::Incoming {
1750 user_id,
1751 should_notify,
1752 } => update
1753 .incoming_requests
1754 .push(proto::IncomingContactRequest {
1755 requester_id: user_id.to_proto(),
1756 should_notify,
1757 }),
1758 }
1759 }
1760
1761 update
1762}
1763
1764fn contact_for_user(
1765 user_id: UserId,
1766 should_notify: bool,
1767 busy: bool,
1768 pool: &ConnectionPool,
1769) -> proto::Contact {
1770 proto::Contact {
1771 user_id: user_id.to_proto(),
1772 online: pool.is_user_online(user_id),
1773 busy,
1774 should_notify,
1775 }
1776}
1777
1778fn room_updated(room: &proto::Room, session: &Session) {
1779 for participant in &room.participants {
1780 session
1781 .peer
1782 .send(
1783 ConnectionId(participant.peer_id),
1784 proto::RoomUpdated {
1785 room: Some(room.clone()),
1786 },
1787 )
1788 .trace_err();
1789 }
1790}
1791
1792async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1793 let db = session.db().await;
1794 let contacts = db.get_contacts(user_id).await?;
1795 let busy = db.is_user_busy(user_id).await?;
1796
1797 let pool = session.connection_pool().await;
1798 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1799 for contact in contacts {
1800 if let db::Contact::Accepted {
1801 user_id: contact_user_id,
1802 ..
1803 } = contact
1804 {
1805 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1806 session
1807 .peer
1808 .send(
1809 contact_conn_id,
1810 proto::UpdateContacts {
1811 contacts: vec![updated_contact.clone()],
1812 remove_contacts: Default::default(),
1813 incoming_requests: Default::default(),
1814 remove_incoming_requests: Default::default(),
1815 outgoing_requests: Default::default(),
1816 remove_outgoing_requests: Default::default(),
1817 },
1818 )
1819 .trace_err();
1820 }
1821 }
1822 }
1823 Ok(())
1824}
1825
1826async fn leave_room_for_session(session: &Session) -> Result<()> {
1827 let mut contacts_to_update = HashSet::default();
1828
1829 let Some(left_room) = session.db().await.leave_room(session.connection_id).await? else {
1830 return Err(anyhow!("no room to leave"))?;
1831 };
1832 contacts_to_update.insert(session.user_id);
1833
1834 for project in left_room.left_projects.into_values() {
1835 for connection_id in project.connection_ids {
1836 if project.host_user_id == session.user_id {
1837 session
1838 .peer
1839 .send(
1840 connection_id,
1841 proto::UnshareProject {
1842 project_id: project.id.to_proto(),
1843 },
1844 )
1845 .trace_err();
1846 } else {
1847 session
1848 .peer
1849 .send(
1850 connection_id,
1851 proto::RemoveProjectCollaborator {
1852 project_id: project.id.to_proto(),
1853 peer_id: session.connection_id.0,
1854 },
1855 )
1856 .trace_err();
1857 }
1858 }
1859
1860 session
1861 .peer
1862 .send(
1863 session.connection_id,
1864 proto::UnshareProject {
1865 project_id: project.id.to_proto(),
1866 },
1867 )
1868 .trace_err();
1869 }
1870
1871 room_updated(&left_room.room, &session);
1872 {
1873 let pool = session.connection_pool().await;
1874 for canceled_user_id in left_room.canceled_calls_to_user_ids {
1875 for connection_id in pool.user_connection_ids(canceled_user_id) {
1876 session
1877 .peer
1878 .send(connection_id, proto::CallCanceled {})
1879 .trace_err();
1880 }
1881 contacts_to_update.insert(canceled_user_id);
1882 }
1883 }
1884
1885 for contact_user_id in contacts_to_update {
1886 update_user_contacts(contact_user_id, &session).await?;
1887 }
1888
1889 if let Some(live_kit) = session.live_kit_client.as_ref() {
1890 live_kit
1891 .remove_participant(
1892 left_room.room.live_kit_room.clone(),
1893 session.connection_id.to_string(),
1894 )
1895 .await
1896 .trace_err();
1897
1898 if left_room.room.participants.is_empty() {
1899 live_kit
1900 .delete_room(left_room.room.live_kit_room)
1901 .await
1902 .trace_err();
1903 }
1904 }
1905
1906 Ok(())
1907}
1908
1909pub trait ResultExt {
1910 type Ok;
1911
1912 fn trace_err(self) -> Option<Self::Ok>;
1913}
1914
1915impl<T, E> ResultExt for Result<T, E>
1916where
1917 E: std::fmt::Debug,
1918{
1919 type Ok = T;
1920
1921 fn trace_err(self) -> Option<T> {
1922 match self {
1923 Ok(value) => Some(value),
1924 Err(error) => {
1925 tracing::error!("{:?}", error);
1926 None
1927 }
1928 }
1929 }
1930}