1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 AppState, Result,
7};
8use anyhow::anyhow;
9use async_tungstenite::tungstenite::{
10 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
11};
12use axum::{
13 body::Body,
14 extract::{
15 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
16 ConnectInfo, WebSocketUpgrade,
17 },
18 headers::{Header, HeaderName},
19 http::StatusCode,
20 middleware,
21 response::IntoResponse,
22 routing::get,
23 Extension, Router, TypedHeader,
24};
25use collections::{HashMap, HashSet};
26pub use connection_pool::ConnectionPool;
27use futures::{
28 channel::oneshot,
29 future::{self, BoxFuture},
30 stream::FuturesUnordered,
31 FutureExt, SinkExt, StreamExt, TryStreamExt,
32};
33use lazy_static::lazy_static;
34use prometheus::{register_int_gauge, IntGauge};
35use rpc::{
36 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
37 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
38};
39use serde::{Serialize, Serializer};
40use std::{
41 any::TypeId,
42 fmt,
43 future::Future,
44 marker::PhantomData,
45 mem,
46 net::SocketAddr,
47 ops::{Deref, DerefMut},
48 rc::Rc,
49 sync::{
50 atomic::{AtomicBool, Ordering::SeqCst},
51 Arc,
52 },
53 time::Duration,
54};
55use tokio::{
56 sync::{Mutex, MutexGuard},
57 time::Sleep,
58};
59use tower::ServiceBuilder;
60use tracing::{info_span, instrument, Instrument};
61
62lazy_static! {
63 static ref METRIC_CONNECTIONS: IntGauge =
64 register_int_gauge!("connections", "number of connections").unwrap();
65 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
66 "shared_projects",
67 "number of open projects with one or more guests"
68 )
69 .unwrap();
70}
71
72type MessageHandler =
73 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
74
75struct Response<R> {
76 peer: Arc<Peer>,
77 receipt: Receipt<R>,
78 responded: Arc<AtomicBool>,
79}
80
81impl<R: RequestMessage> Response<R> {
82 fn send(self, payload: R::Response) -> Result<()> {
83 self.responded.store(true, SeqCst);
84 self.peer.respond(self.receipt, payload)?;
85 Ok(())
86 }
87}
88
89#[derive(Clone)]
90struct Session {
91 user_id: UserId,
92 connection_id: ConnectionId,
93 db: Arc<Mutex<DbHandle>>,
94 peer: Arc<Peer>,
95 connection_pool: Arc<Mutex<ConnectionPool>>,
96 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
97}
98
99impl Session {
100 async fn db(&self) -> MutexGuard<DbHandle> {
101 #[cfg(test)]
102 tokio::task::yield_now().await;
103 let guard = self.db.lock().await;
104 #[cfg(test)]
105 tokio::task::yield_now().await;
106 guard
107 }
108
109 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
110 #[cfg(test)]
111 tokio::task::yield_now().await;
112 let guard = self.connection_pool.lock().await;
113 #[cfg(test)]
114 tokio::task::yield_now().await;
115 ConnectionPoolGuard {
116 guard,
117 _not_send: PhantomData,
118 }
119 }
120}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
141pub struct Server {
142 peer: Arc<Peer>,
143 pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
144 app_state: Arc<AppState>,
145 handlers: HashMap<TypeId, MessageHandler>,
146}
147
/// Abstraction over the async runtime used by [`Server::handle_connection`].
pub trait Executor: Send + Clone {
    /// Timer future returned by [`Executor::sleep`].
    type Sleep: Send + Future;

    /// Starts `future` in the background; the caller does not await it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);

    /// Returns a timer future for `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
153
/// Production [`Executor`], backed by the tokio runtime.
#[derive(Clone)]
pub struct RealExecutor;
156
157pub(crate) struct ConnectionPoolGuard<'a> {
158 guard: MutexGuard<'a, ConnectionPool>,
159 _not_send: PhantomData<Rc<()>>,
160}
161
162#[derive(Serialize)]
163pub struct ServerSnapshot<'a> {
164 peer: &'a Peer,
165 #[serde(serialize_with = "serialize_deref")]
166 connection_pool: ConnectionPoolGuard<'a>,
167}
168
169pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
170where
171 S: Serializer,
172 T: Deref<Target = U>,
173 U: Serialize,
174{
175 Serialize::serialize(value.deref(), serializer)
176}
177
178impl Server {
179 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
180 let mut server = Self {
181 peer: Peer::new(),
182 app_state,
183 connection_pool: Default::default(),
184 handlers: Default::default(),
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_message_handler(leave_room)
192 .add_request_handler(call)
193 .add_request_handler(cancel_call)
194 .add_message_handler(decline_call)
195 .add_request_handler(update_participant_location)
196 .add_request_handler(share_project)
197 .add_message_handler(unshare_project)
198 .add_request_handler(join_project)
199 .add_message_handler(leave_project)
200 .add_request_handler(update_project)
201 .add_request_handler(update_worktree)
202 .add_message_handler(start_language_server)
203 .add_message_handler(update_language_server)
204 .add_message_handler(update_diagnostic_summary)
205 .add_request_handler(forward_project_request::<proto::GetHover>)
206 .add_request_handler(forward_project_request::<proto::GetDefinition>)
207 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
208 .add_request_handler(forward_project_request::<proto::GetReferences>)
209 .add_request_handler(forward_project_request::<proto::SearchProject>)
210 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
211 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
212 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
213 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
214 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
215 .add_request_handler(forward_project_request::<proto::GetCompletions>)
216 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
217 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
218 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
219 .add_request_handler(forward_project_request::<proto::PrepareRename>)
220 .add_request_handler(forward_project_request::<proto::PerformRename>)
221 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
222 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
223 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
224 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
225 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
226 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
227 .add_message_handler(create_buffer_for_peer)
228 .add_request_handler(update_buffer)
229 .add_message_handler(update_buffer_file)
230 .add_message_handler(buffer_reloaded)
231 .add_message_handler(buffer_saved)
232 .add_request_handler(save_buffer)
233 .add_request_handler(get_users)
234 .add_request_handler(fuzzy_search_users)
235 .add_request_handler(request_contact)
236 .add_request_handler(remove_contact)
237 .add_request_handler(respond_to_contact_request)
238 .add_request_handler(follow)
239 .add_message_handler(unfollow)
240 .add_message_handler(update_followers)
241 .add_message_handler(update_diff_base)
242 .add_request_handler(get_private_user_info);
243
244 Arc::new(server)
245 }
246
247 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
248 where
249 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
250 Fut: 'static + Send + Future<Output = Result<()>>,
251 M: EnvelopedMessage,
252 {
253 let prev_handler = self.handlers.insert(
254 TypeId::of::<M>(),
255 Box::new(move |envelope, session| {
256 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
257 let span = info_span!(
258 "handle message",
259 payload_type = envelope.payload_type_name()
260 );
261 span.in_scope(|| {
262 tracing::info!(
263 payload_type = envelope.payload_type_name(),
264 "message received"
265 );
266 });
267 let future = (handler)(*envelope, session);
268 async move {
269 if let Err(error) = future.await {
270 tracing::error!(%error, "error handling message");
271 }
272 }
273 .instrument(span)
274 .boxed()
275 }),
276 );
277 if prev_handler.is_some() {
278 panic!("registered a handler for the same message twice");
279 }
280 self
281 }
282
283 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
284 where
285 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
286 Fut: 'static + Send + Future<Output = Result<()>>,
287 M: EnvelopedMessage,
288 {
289 self.add_handler(move |envelope, session| handler(envelope.payload, session));
290 self
291 }
292
293 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
294 where
295 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
296 Fut: Send + Future<Output = Result<()>>,
297 M: RequestMessage,
298 {
299 let handler = Arc::new(handler);
300 self.add_handler(move |envelope, session| {
301 let receipt = envelope.receipt();
302 let handler = handler.clone();
303 async move {
304 let peer = session.peer.clone();
305 let responded = Arc::new(AtomicBool::default());
306 let response = Response {
307 peer: peer.clone(),
308 responded: responded.clone(),
309 receipt,
310 };
311 match (handler)(envelope.payload, response, session).await {
312 Ok(()) => {
313 if responded.load(std::sync::atomic::Ordering::SeqCst) {
314 Ok(())
315 } else {
316 Err(anyhow!("handler did not send a response"))?
317 }
318 }
319 Err(error) => {
320 peer.respond_with_error(
321 receipt,
322 proto::Error {
323 message: error.to_string(),
324 },
325 )?;
326 Err(error)
327 }
328 }
329 }
330 })
331 }
332
333 pub fn handle_connection<E: Executor>(
334 self: &Arc<Self>,
335 connection: Connection,
336 address: String,
337 user: User,
338 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
339 executor: E,
340 ) -> impl Future<Output = Result<()>> {
341 let this = self.clone();
342 let user_id = user.id;
343 let login = user.github_login;
344 let span = info_span!("handle connection", %user_id, %login, %address);
345 async move {
346 let (connection_id, handle_io, mut incoming_rx) = this
347 .peer
348 .add_connection(connection, {
349 let executor = executor.clone();
350 move |duration| {
351 let timer = executor.sleep(duration);
352 async move {
353 timer.await;
354 }
355 }
356 });
357
358 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
359 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
360 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
361
362 if let Some(send_connection_id) = send_connection_id.take() {
363 let _ = send_connection_id.send(connection_id);
364 }
365
366 if !user.connected_once {
367 this.peer.send(connection_id, proto::ShowContacts {})?;
368 this.app_state.db.set_user_connected_once(user_id, true).await?;
369 }
370
371 let (contacts, invite_code) = future::try_join(
372 this.app_state.db.get_contacts(user_id),
373 this.app_state.db.get_invite_code_for_user(user_id)
374 ).await?;
375
376 {
377 let mut pool = this.connection_pool.lock().await;
378 pool.add_connection(connection_id, user_id, user.admin);
379 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
380
381 if let Some((code, count)) = invite_code {
382 this.peer.send(connection_id, proto::UpdateInviteInfo {
383 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
384 count: count as u32,
385 })?;
386 }
387 }
388
389 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
390 this.peer.send(connection_id, incoming_call)?;
391 }
392
393 let session = Session {
394 user_id,
395 connection_id,
396 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
397 peer: this.peer.clone(),
398 connection_pool: this.connection_pool.clone(),
399 live_kit_client: this.app_state.live_kit_client.clone()
400 };
401 update_user_contacts(user_id, &session).await?;
402
403 let handle_io = handle_io.fuse();
404 futures::pin_mut!(handle_io);
405
406 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
407 // This prevents deadlocks when e.g., client A performs a request to client B and
408 // client B performs a request to client A. If both clients stop processing further
409 // messages until their respective request completes, they won't have a chance to
410 // respond to the other client's request and cause a deadlock.
411 //
412 // This arrangement ensures we will attempt to process earlier messages first, but fall
413 // back to processing messages arrived later in the spirit of making progress.
414 let mut foreground_message_handlers = FuturesUnordered::new();
415 loop {
416 let next_message = incoming_rx.next().fuse();
417 futures::pin_mut!(next_message);
418 futures::select_biased! {
419 result = handle_io => {
420 if let Err(error) = result {
421 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
422 }
423 break;
424 }
425 _ = foreground_message_handlers.next() => {}
426 message = next_message => {
427 if let Some(message) = message {
428 let type_name = message.payload_type_name();
429 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
430 let span_enter = span.enter();
431 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
432 let is_background = message.is_background();
433 let handle_message = (handler)(message, session.clone());
434 drop(span_enter);
435
436 let handle_message = handle_message.instrument(span);
437 if is_background {
438 executor.spawn_detached(handle_message);
439 } else {
440 foreground_message_handlers.push(handle_message);
441 }
442 } else {
443 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
444 }
445 } else {
446 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
447 break;
448 }
449 }
450 }
451 }
452
453 drop(foreground_message_handlers);
454 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
455 if let Err(error) = sign_out(session).await {
456 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
457 }
458
459 Ok(())
460 }.instrument(span)
461 }
462
463 pub async fn invite_code_redeemed(
464 self: &Arc<Self>,
465 inviter_id: UserId,
466 invitee_id: UserId,
467 ) -> Result<()> {
468 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
469 if let Some(code) = &user.invite_code {
470 let pool = self.connection_pool.lock().await;
471 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
472 for connection_id in pool.user_connection_ids(inviter_id) {
473 self.peer.send(
474 connection_id,
475 proto::UpdateContacts {
476 contacts: vec![invitee_contact.clone()],
477 ..Default::default()
478 },
479 )?;
480 self.peer.send(
481 connection_id,
482 proto::UpdateInviteInfo {
483 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
484 count: user.invite_count as u32,
485 },
486 )?;
487 }
488 }
489 }
490 Ok(())
491 }
492
493 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
494 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
495 if let Some(invite_code) = &user.invite_code {
496 let pool = self.connection_pool.lock().await;
497 for connection_id in pool.user_connection_ids(user_id) {
498 self.peer.send(
499 connection_id,
500 proto::UpdateInviteInfo {
501 url: format!(
502 "{}{}",
503 self.app_state.config.invite_link_prefix, invite_code
504 ),
505 count: user.invite_count as u32,
506 },
507 )?;
508 }
509 }
510 }
511 Ok(())
512 }
513
514 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
515 ServerSnapshot {
516 connection_pool: ConnectionPoolGuard {
517 guard: self.connection_pool.lock().await,
518 _not_send: PhantomData,
519 },
520 peer: &self.peer,
521 }
522 }
523}
524
525impl<'a> Deref for ConnectionPoolGuard<'a> {
526 type Target = ConnectionPool;
527
528 fn deref(&self) -> &Self::Target {
529 &*self.guard
530 }
531}
532
533impl<'a> DerefMut for ConnectionPoolGuard<'a> {
534 fn deref_mut(&mut self) -> &mut Self::Target {
535 &mut *self.guard
536 }
537}
538
539impl<'a> Drop for ConnectionPoolGuard<'a> {
540 fn drop(&mut self) {
541 #[cfg(test)]
542 self.check_invariants();
543 }
544}
545
546impl Executor for RealExecutor {
547 type Sleep = Sleep;
548
549 fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
550 tokio::task::spawn(future);
551 }
552
553 fn sleep(&self, duration: Duration) -> Self::Sleep {
554 tokio::time::sleep(duration)
555 }
556}
557
558fn broadcast<F>(
559 sender_id: ConnectionId,
560 receiver_ids: impl IntoIterator<Item = ConnectionId>,
561 mut f: F,
562) where
563 F: FnMut(ConnectionId) -> anyhow::Result<()>,
564{
565 for receiver_id in receiver_ids {
566 if receiver_id != sender_id {
567 f(receiver_id).trace_err();
568 }
569 }
570}
571
572lazy_static! {
573 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
574}
575
576pub struct ProtocolVersion(u32);
577
578impl Header for ProtocolVersion {
579 fn name() -> &'static HeaderName {
580 &ZED_PROTOCOL_VERSION
581 }
582
583 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
584 where
585 Self: Sized,
586 I: Iterator<Item = &'i axum::http::HeaderValue>,
587 {
588 let version = values
589 .next()
590 .ok_or_else(axum::headers::Error::invalid)?
591 .to_str()
592 .map_err(|_| axum::headers::Error::invalid())?
593 .parse()
594 .map_err(|_| axum::headers::Error::invalid())?;
595 Ok(Self(version))
596 }
597
598 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
599 values.extend([self.0.to_string().parse().unwrap()]);
600 }
601}
602
/// Builds the collab server's HTTP router.
///
/// `/rpc` is served by `handle_websocket_request`. `/metrics` is served by `handle_metrics`.
/// Both routes receive the `Arc<Server>` via the final `Extension(server)` layer.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Adds the `AppState` extension and the `auth::validate_header` middleware.
        // NOTE(review): axum's `Router::layer` wraps only the routes added before the call, so
        // this appears to apply to `/rpc` but not to `/metrics` — confirm that is intended.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
614
615pub async fn handle_websocket_request(
616 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
617 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
618 Extension(server): Extension<Arc<Server>>,
619 Extension(user): Extension<User>,
620 ws: WebSocketUpgrade,
621) -> axum::response::Response {
622 if protocol_version != rpc::PROTOCOL_VERSION {
623 return (
624 StatusCode::UPGRADE_REQUIRED,
625 "client must be upgraded".to_string(),
626 )
627 .into_response();
628 }
629 let socket_address = socket_address.to_string();
630 ws.on_upgrade(move |socket| {
631 use util::ResultExt;
632 let socket = socket
633 .map_ok(to_tungstenite_message)
634 .err_into()
635 .with(|message| async move { Ok(to_axum_message(message)) });
636 let connection = Connection::new(Box::pin(socket));
637 async move {
638 server
639 .handle_connection(connection, socket_address, user, None, RealExecutor)
640 .await
641 .log_err();
642 }
643 })
644}
645
646pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
647 let connections = server
648 .connection_pool
649 .lock()
650 .await
651 .connections()
652 .filter(|connection| !connection.admin)
653 .count();
654
655 METRIC_CONNECTIONS.set(connections as _);
656
657 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
658 METRIC_SHARED_PROJECTS.set(shared_projects as _);
659
660 let encoder = prometheus::TextEncoder::new();
661 let metric_families = prometheus::gather();
662 let encoded_metrics = encoder
663 .encode_to_string(&metric_families)
664 .map_err(|err| anyhow!("{}", err))?;
665 Ok(encoded_metrics)
666}
667
/// Tears down a session once its connection's message loop has ended.
///
/// Disconnects the peer and removes the connection from the pool. It then leaves the session's
/// room (if any). If this was the user's last connection, it calls `decline_call(None, ..)` for
/// the user. Finally it refreshes the user's contacts.
///
/// # Errors
///
/// Returns early if the connection cannot be removed from the pool or if updating contacts
/// fails. Failures to leave the room or decline calls only go through `trace_err`.
#[instrument(err)]
async fn sign_out(session: Session) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    // Decide, while holding the pool lock, whether the user has any connection left; the guard
    // is dropped at the end of this block, before any of the awaits below.
    let decline_calls = {
        let mut pool = session.connection_pool().await;
        pool.remove_connection(session.connection_id)?;
        let mut connections = pool.user_connection_ids(session.user_id);
        connections.next().is_none()
    };

    leave_room_for_session(&session).await.trace_err();
    if decline_calls {
        if let Some(room) = session
            .db()
            .await
            .decline_call(None, session.user_id)
            .await
            .trace_err()
        {
            room_updated(&room, &session);
        }
    }

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
695
696async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
697 response.send(proto::Ack {})?;
698 Ok(())
699}
700
701async fn create_room(
702 _request: proto::CreateRoom,
703 response: Response<proto::CreateRoom>,
704 session: Session,
705) -> Result<()> {
706 let live_kit_room = nanoid::nanoid!(30);
707 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
708 if let Some(_) = live_kit
709 .create_room(live_kit_room.clone())
710 .await
711 .trace_err()
712 {
713 if let Some(token) = live_kit
714 .room_token(&live_kit_room, &session.connection_id.to_string())
715 .trace_err()
716 {
717 Some(proto::LiveKitConnectionInfo {
718 server_url: live_kit.url().into(),
719 token,
720 })
721 } else {
722 None
723 }
724 } else {
725 None
726 }
727 } else {
728 None
729 };
730
731 {
732 let room = session
733 .db()
734 .await
735 .create_room(session.user_id, session.connection_id, &live_kit_room)
736 .await?;
737
738 response.send(proto::CreateRoomResponse {
739 room: Some(room.clone()),
740 live_kit_connection_info,
741 })?;
742 }
743
744 update_user_contacts(session.user_id, &session).await?;
745 Ok(())
746}
747
748async fn join_room(
749 request: proto::JoinRoom,
750 response: Response<proto::JoinRoom>,
751 session: Session,
752) -> Result<()> {
753 let room = {
754 let room = session
755 .db()
756 .await
757 .join_room(
758 RoomId::from_proto(request.id),
759 session.user_id,
760 session.connection_id,
761 )
762 .await?;
763 room_updated(&room, &session);
764 room.clone()
765 };
766
767 for connection_id in session
768 .connection_pool()
769 .await
770 .user_connection_ids(session.user_id)
771 {
772 session
773 .peer
774 .send(connection_id, proto::CallCanceled {})
775 .trace_err();
776 }
777
778 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
779 if let Some(token) = live_kit
780 .room_token(&room.live_kit_room, &session.connection_id.to_string())
781 .trace_err()
782 {
783 Some(proto::LiveKitConnectionInfo {
784 server_url: live_kit.url().into(),
785 token,
786 })
787 } else {
788 None
789 }
790 } else {
791 None
792 };
793
794 response.send(proto::JoinRoomResponse {
795 room: Some(room),
796 live_kit_connection_info,
797 })?;
798
799 update_user_contacts(session.user_id, &session).await?;
800 Ok(())
801}
802
803async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
804 leave_room_for_session(&session).await
805}
806
807async fn call(
808 request: proto::Call,
809 response: Response<proto::Call>,
810 session: Session,
811) -> Result<()> {
812 let room_id = RoomId::from_proto(request.room_id);
813 let calling_user_id = session.user_id;
814 let calling_connection_id = session.connection_id;
815 let called_user_id = UserId::from_proto(request.called_user_id);
816 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
817 if !session
818 .db()
819 .await
820 .has_contact(calling_user_id, called_user_id)
821 .await?
822 {
823 return Err(anyhow!("cannot call a user who isn't a contact"))?;
824 }
825
826 let incoming_call = {
827 let (room, incoming_call) = &mut *session
828 .db()
829 .await
830 .call(
831 room_id,
832 calling_user_id,
833 calling_connection_id,
834 called_user_id,
835 initial_project_id,
836 )
837 .await?;
838 room_updated(&room, &session);
839 mem::take(incoming_call)
840 };
841 update_user_contacts(called_user_id, &session).await?;
842
843 let mut calls = session
844 .connection_pool()
845 .await
846 .user_connection_ids(called_user_id)
847 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
848 .collect::<FuturesUnordered<_>>();
849
850 while let Some(call_response) = calls.next().await {
851 match call_response.as_ref() {
852 Ok(_) => {
853 response.send(proto::Ack {})?;
854 return Ok(());
855 }
856 Err(_) => {
857 call_response.trace_err();
858 }
859 }
860 }
861
862 {
863 let room = session
864 .db()
865 .await
866 .call_failed(room_id, called_user_id)
867 .await?;
868 room_updated(&room, &session);
869 }
870 update_user_contacts(called_user_id, &session).await?;
871
872 Err(anyhow!("failed to ring user"))?
873}
874
875async fn cancel_call(
876 request: proto::CancelCall,
877 response: Response<proto::CancelCall>,
878 session: Session,
879) -> Result<()> {
880 let called_user_id = UserId::from_proto(request.called_user_id);
881 let room_id = RoomId::from_proto(request.room_id);
882 {
883 let room = session
884 .db()
885 .await
886 .cancel_call(Some(room_id), session.connection_id, called_user_id)
887 .await?;
888 room_updated(&room, &session);
889 }
890
891 for connection_id in session
892 .connection_pool()
893 .await
894 .user_connection_ids(called_user_id)
895 {
896 session
897 .peer
898 .send(connection_id, proto::CallCanceled {})
899 .trace_err();
900 }
901 response.send(proto::Ack {})?;
902
903 update_user_contacts(called_user_id, &session).await?;
904 Ok(())
905}
906
907async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
908 let room_id = RoomId::from_proto(message.room_id);
909 {
910 let room = session
911 .db()
912 .await
913 .decline_call(Some(room_id), session.user_id)
914 .await?;
915 room_updated(&room, &session);
916 }
917
918 for connection_id in session
919 .connection_pool()
920 .await
921 .user_connection_ids(session.user_id)
922 {
923 session
924 .peer
925 .send(connection_id, proto::CallCanceled {})
926 .trace_err();
927 }
928 update_user_contacts(session.user_id, &session).await?;
929 Ok(())
930}
931
932async fn update_participant_location(
933 request: proto::UpdateParticipantLocation,
934 response: Response<proto::UpdateParticipantLocation>,
935 session: Session,
936) -> Result<()> {
937 let room_id = RoomId::from_proto(request.room_id);
938 let location = request
939 .location
940 .ok_or_else(|| anyhow!("invalid location"))?;
941 let room = session
942 .db()
943 .await
944 .update_room_participant_location(room_id, session.connection_id, location)
945 .await?;
946 room_updated(&room, &session);
947 response.send(proto::Ack {})?;
948 Ok(())
949}
950
951async fn share_project(
952 request: proto::ShareProject,
953 response: Response<proto::ShareProject>,
954 session: Session,
955) -> Result<()> {
956 let (project_id, room) = &*session
957 .db()
958 .await
959 .share_project(
960 RoomId::from_proto(request.room_id),
961 session.connection_id,
962 &request.worktrees,
963 )
964 .await?;
965 response.send(proto::ShareProjectResponse {
966 project_id: project_id.to_proto(),
967 })?;
968 room_updated(&room, &session);
969
970 Ok(())
971}
972
973async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
974 let project_id = ProjectId::from_proto(message.project_id);
975
976 let (room, guest_connection_ids) = &*session
977 .db()
978 .await
979 .unshare_project(project_id, session.connection_id)
980 .await?;
981
982 broadcast(
983 session.connection_id,
984 guest_connection_ids.iter().copied(),
985 |conn_id| session.peer.send(conn_id, message.clone()),
986 );
987 room_updated(&room, &session);
988
989 Ok(())
990}
991
992async fn join_project(
993 request: proto::JoinProject,
994 response: Response<proto::JoinProject>,
995 session: Session,
996) -> Result<()> {
997 let project_id = ProjectId::from_proto(request.project_id);
998 let guest_user_id = session.user_id;
999
1000 tracing::info!(%project_id, "join project");
1001
1002 let (project, replica_id) = &mut *session
1003 .db()
1004 .await
1005 .join_project(project_id, session.connection_id)
1006 .await?;
1007
1008 let collaborators = project
1009 .collaborators
1010 .iter()
1011 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1012 .map(|collaborator| proto::Collaborator {
1013 peer_id: collaborator.connection_id as u32,
1014 replica_id: collaborator.replica_id.0 as u32,
1015 user_id: collaborator.user_id.to_proto(),
1016 })
1017 .collect::<Vec<_>>();
1018 let worktrees = project
1019 .worktrees
1020 .iter()
1021 .map(|(id, worktree)| proto::WorktreeMetadata {
1022 id: *id,
1023 root_name: worktree.root_name.clone(),
1024 visible: worktree.visible,
1025 abs_path: worktree.abs_path.clone(),
1026 })
1027 .collect::<Vec<_>>();
1028
1029 for collaborator in &collaborators {
1030 session
1031 .peer
1032 .send(
1033 ConnectionId(collaborator.peer_id),
1034 proto::AddProjectCollaborator {
1035 project_id: project_id.to_proto(),
1036 collaborator: Some(proto::Collaborator {
1037 peer_id: session.connection_id.0,
1038 replica_id: replica_id.0 as u32,
1039 user_id: guest_user_id.to_proto(),
1040 }),
1041 },
1042 )
1043 .trace_err();
1044 }
1045
1046 // First, we send the metadata associated with each worktree.
1047 response.send(proto::JoinProjectResponse {
1048 worktrees: worktrees.clone(),
1049 replica_id: replica_id.0 as u32,
1050 collaborators: collaborators.clone(),
1051 language_servers: project.language_servers.clone(),
1052 })?;
1053
1054 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1055 #[cfg(any(test, feature = "test-support"))]
1056 const MAX_CHUNK_SIZE: usize = 2;
1057 #[cfg(not(any(test, feature = "test-support")))]
1058 const MAX_CHUNK_SIZE: usize = 256;
1059
1060 // Stream this worktree's entries.
1061 let message = proto::UpdateWorktree {
1062 project_id: project_id.to_proto(),
1063 worktree_id,
1064 abs_path: worktree.abs_path.clone(),
1065 root_name: worktree.root_name,
1066 updated_entries: worktree.entries,
1067 removed_entries: Default::default(),
1068 scan_id: worktree.scan_id,
1069 is_last_update: worktree.is_complete,
1070 };
1071 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1072 session.peer.send(session.connection_id, update.clone())?;
1073 }
1074
1075 // Stream this worktree's diagnostics.
1076 for summary in worktree.diagnostic_summaries {
1077 session.peer.send(
1078 session.connection_id,
1079 proto::UpdateDiagnosticSummary {
1080 project_id: project_id.to_proto(),
1081 worktree_id: worktree.id,
1082 summary: Some(summary),
1083 },
1084 )?;
1085 }
1086 }
1087
1088 for language_server in &project.language_servers {
1089 session.peer.send(
1090 session.connection_id,
1091 proto::UpdateLanguageServer {
1092 project_id: project_id.to_proto(),
1093 language_server_id: language_server.id,
1094 variant: Some(
1095 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1096 proto::LspDiskBasedDiagnosticsUpdated {},
1097 ),
1098 ),
1099 },
1100 )?;
1101 }
1102
1103 Ok(())
1104}
1105
1106async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1107 let sender_id = session.connection_id;
1108 let project_id = ProjectId::from_proto(request.project_id);
1109
1110 let project = session
1111 .db()
1112 .await
1113 .leave_project(project_id, sender_id)
1114 .await?;
1115 tracing::info!(
1116 %project_id,
1117 host_user_id = %project.host_user_id,
1118 host_connection_id = %project.host_connection_id,
1119 "leave project"
1120 );
1121
1122 broadcast(
1123 sender_id,
1124 project.connection_ids.iter().copied(),
1125 |conn_id| {
1126 session.peer.send(
1127 conn_id,
1128 proto::RemoveProjectCollaborator {
1129 project_id: project_id.to_proto(),
1130 peer_id: sender_id.0,
1131 },
1132 )
1133 },
1134 );
1135
1136 Ok(())
1137}
1138
1139async fn update_project(
1140 request: proto::UpdateProject,
1141 response: Response<proto::UpdateProject>,
1142 session: Session,
1143) -> Result<()> {
1144 let project_id = ProjectId::from_proto(request.project_id);
1145 let (room, guest_connection_ids) = &*session
1146 .db()
1147 .await
1148 .update_project(project_id, session.connection_id, &request.worktrees)
1149 .await?;
1150 broadcast(
1151 session.connection_id,
1152 guest_connection_ids.iter().copied(),
1153 |connection_id| {
1154 session
1155 .peer
1156 .forward_send(session.connection_id, connection_id, request.clone())
1157 },
1158 );
1159 room_updated(&room, &session);
1160 response.send(proto::Ack {})?;
1161
1162 Ok(())
1163}
1164
1165async fn update_worktree(
1166 request: proto::UpdateWorktree,
1167 response: Response<proto::UpdateWorktree>,
1168 session: Session,
1169) -> Result<()> {
1170 let guest_connection_ids = session
1171 .db()
1172 .await
1173 .update_worktree(&request, session.connection_id)
1174 .await?;
1175
1176 broadcast(
1177 session.connection_id,
1178 guest_connection_ids.iter().copied(),
1179 |connection_id| {
1180 session
1181 .peer
1182 .forward_send(session.connection_id, connection_id, request.clone())
1183 },
1184 );
1185 response.send(proto::Ack {})?;
1186 Ok(())
1187}
1188
1189async fn update_diagnostic_summary(
1190 message: proto::UpdateDiagnosticSummary,
1191 session: Session,
1192) -> Result<()> {
1193 let guest_connection_ids = session
1194 .db()
1195 .await
1196 .update_diagnostic_summary(&message, session.connection_id)
1197 .await?;
1198
1199 broadcast(
1200 session.connection_id,
1201 guest_connection_ids.iter().copied(),
1202 |connection_id| {
1203 session
1204 .peer
1205 .forward_send(session.connection_id, connection_id, message.clone())
1206 },
1207 );
1208
1209 Ok(())
1210}
1211
1212async fn start_language_server(
1213 request: proto::StartLanguageServer,
1214 session: Session,
1215) -> Result<()> {
1216 let guest_connection_ids = session
1217 .db()
1218 .await
1219 .start_language_server(&request, session.connection_id)
1220 .await?;
1221
1222 broadcast(
1223 session.connection_id,
1224 guest_connection_ids.iter().copied(),
1225 |connection_id| {
1226 session
1227 .peer
1228 .forward_send(session.connection_id, connection_id, request.clone())
1229 },
1230 );
1231 Ok(())
1232}
1233
1234async fn update_language_server(
1235 request: proto::UpdateLanguageServer,
1236 session: Session,
1237) -> Result<()> {
1238 let project_id = ProjectId::from_proto(request.project_id);
1239 let project_connection_ids = session
1240 .db()
1241 .await
1242 .project_connection_ids(project_id, session.connection_id)
1243 .await?;
1244 broadcast(
1245 session.connection_id,
1246 project_connection_ids.iter().copied(),
1247 |connection_id| {
1248 session
1249 .peer
1250 .forward_send(session.connection_id, connection_id, request.clone())
1251 },
1252 );
1253 Ok(())
1254}
1255
1256async fn forward_project_request<T>(
1257 request: T,
1258 response: Response<T>,
1259 session: Session,
1260) -> Result<()>
1261where
1262 T: EntityMessage + RequestMessage,
1263{
1264 let project_id = ProjectId::from_proto(request.remote_entity_id());
1265 let host_connection_id = {
1266 let collaborators = session
1267 .db()
1268 .await
1269 .project_collaborators(project_id, session.connection_id)
1270 .await?;
1271 ConnectionId(
1272 collaborators
1273 .iter()
1274 .find(|collaborator| collaborator.is_host)
1275 .ok_or_else(|| anyhow!("host not found"))?
1276 .connection_id as u32,
1277 )
1278 };
1279
1280 let payload = session
1281 .peer
1282 .forward_request(session.connection_id, host_connection_id, request)
1283 .await?;
1284
1285 response.send(payload)?;
1286 Ok(())
1287}
1288
1289async fn save_buffer(
1290 request: proto::SaveBuffer,
1291 response: Response<proto::SaveBuffer>,
1292 session: Session,
1293) -> Result<()> {
1294 let project_id = ProjectId::from_proto(request.project_id);
1295 let host_connection_id = {
1296 let collaborators = session
1297 .db()
1298 .await
1299 .project_collaborators(project_id, session.connection_id)
1300 .await?;
1301 let host = collaborators
1302 .iter()
1303 .find(|collaborator| collaborator.is_host)
1304 .ok_or_else(|| anyhow!("host not found"))?;
1305 ConnectionId(host.connection_id as u32)
1306 };
1307 let response_payload = session
1308 .peer
1309 .forward_request(session.connection_id, host_connection_id, request.clone())
1310 .await?;
1311
1312 let mut collaborators = session
1313 .db()
1314 .await
1315 .project_collaborators(project_id, session.connection_id)
1316 .await?;
1317 collaborators
1318 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1319 let project_connection_ids = collaborators
1320 .iter()
1321 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1322 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1323 session
1324 .peer
1325 .forward_send(host_connection_id, conn_id, response_payload.clone())
1326 });
1327 response.send(response_payload)?;
1328 Ok(())
1329}
1330
1331async fn create_buffer_for_peer(
1332 request: proto::CreateBufferForPeer,
1333 session: Session,
1334) -> Result<()> {
1335 session.peer.forward_send(
1336 session.connection_id,
1337 ConnectionId(request.peer_id),
1338 request,
1339 )?;
1340 Ok(())
1341}
1342
1343async fn update_buffer(
1344 request: proto::UpdateBuffer,
1345 response: Response<proto::UpdateBuffer>,
1346 session: Session,
1347) -> Result<()> {
1348 let project_id = ProjectId::from_proto(request.project_id);
1349 let project_connection_ids = session
1350 .db()
1351 .await
1352 .project_connection_ids(project_id, session.connection_id)
1353 .await?;
1354
1355 broadcast(
1356 session.connection_id,
1357 project_connection_ids.iter().copied(),
1358 |connection_id| {
1359 session
1360 .peer
1361 .forward_send(session.connection_id, connection_id, request.clone())
1362 },
1363 );
1364 response.send(proto::Ack {})?;
1365 Ok(())
1366}
1367
1368async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1369 let project_id = ProjectId::from_proto(request.project_id);
1370 let project_connection_ids = session
1371 .db()
1372 .await
1373 .project_connection_ids(project_id, session.connection_id)
1374 .await?;
1375
1376 broadcast(
1377 session.connection_id,
1378 project_connection_ids.iter().copied(),
1379 |connection_id| {
1380 session
1381 .peer
1382 .forward_send(session.connection_id, connection_id, request.clone())
1383 },
1384 );
1385 Ok(())
1386}
1387
1388async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1389 let project_id = ProjectId::from_proto(request.project_id);
1390 let project_connection_ids = session
1391 .db()
1392 .await
1393 .project_connection_ids(project_id, session.connection_id)
1394 .await?;
1395 broadcast(
1396 session.connection_id,
1397 project_connection_ids.iter().copied(),
1398 |connection_id| {
1399 session
1400 .peer
1401 .forward_send(session.connection_id, connection_id, request.clone())
1402 },
1403 );
1404 Ok(())
1405}
1406
1407async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1408 let project_id = ProjectId::from_proto(request.project_id);
1409 let project_connection_ids = session
1410 .db()
1411 .await
1412 .project_connection_ids(project_id, session.connection_id)
1413 .await?;
1414 broadcast(
1415 session.connection_id,
1416 project_connection_ids.iter().copied(),
1417 |connection_id| {
1418 session
1419 .peer
1420 .forward_send(session.connection_id, connection_id, request.clone())
1421 },
1422 );
1423 Ok(())
1424}
1425
1426async fn follow(
1427 request: proto::Follow,
1428 response: Response<proto::Follow>,
1429 session: Session,
1430) -> Result<()> {
1431 let project_id = ProjectId::from_proto(request.project_id);
1432 let leader_id = ConnectionId(request.leader_id);
1433 let follower_id = session.connection_id;
1434 {
1435 let project_connection_ids = session
1436 .db()
1437 .await
1438 .project_connection_ids(project_id, session.connection_id)
1439 .await?;
1440
1441 if !project_connection_ids.contains(&leader_id) {
1442 Err(anyhow!("no such peer"))?;
1443 }
1444 }
1445
1446 let mut response_payload = session
1447 .peer
1448 .forward_request(session.connection_id, leader_id, request)
1449 .await?;
1450 response_payload
1451 .views
1452 .retain(|view| view.leader_id != Some(follower_id.0));
1453 response.send(response_payload)?;
1454 Ok(())
1455}
1456
1457async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1458 let project_id = ProjectId::from_proto(request.project_id);
1459 let leader_id = ConnectionId(request.leader_id);
1460 let project_connection_ids = session
1461 .db()
1462 .await
1463 .project_connection_ids(project_id, session.connection_id)
1464 .await?;
1465 if !project_connection_ids.contains(&leader_id) {
1466 Err(anyhow!("no such peer"))?;
1467 }
1468 session
1469 .peer
1470 .forward_send(session.connection_id, leader_id, request)?;
1471 Ok(())
1472}
1473
1474async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1475 let project_id = ProjectId::from_proto(request.project_id);
1476 let project_connection_ids = session
1477 .db
1478 .lock()
1479 .await
1480 .project_connection_ids(project_id, session.connection_id)
1481 .await?;
1482
1483 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1484 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1485 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1486 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1487 });
1488 for follower_id in &request.follower_ids {
1489 let follower_id = ConnectionId(*follower_id);
1490 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1491 session
1492 .peer
1493 .forward_send(session.connection_id, follower_id, request.clone())?;
1494 }
1495 }
1496 Ok(())
1497}
1498
1499async fn get_users(
1500 request: proto::GetUsers,
1501 response: Response<proto::GetUsers>,
1502 session: Session,
1503) -> Result<()> {
1504 let user_ids = request
1505 .user_ids
1506 .into_iter()
1507 .map(UserId::from_proto)
1508 .collect();
1509 let users = session
1510 .db()
1511 .await
1512 .get_users_by_ids(user_ids)
1513 .await?
1514 .into_iter()
1515 .map(|user| proto::User {
1516 id: user.id.to_proto(),
1517 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1518 github_login: user.github_login,
1519 })
1520 .collect();
1521 response.send(proto::UsersResponse { users })?;
1522 Ok(())
1523}
1524
1525async fn fuzzy_search_users(
1526 request: proto::FuzzySearchUsers,
1527 response: Response<proto::FuzzySearchUsers>,
1528 session: Session,
1529) -> Result<()> {
1530 let query = request.query;
1531 let users = match query.len() {
1532 0 => vec![],
1533 1 | 2 => session
1534 .db()
1535 .await
1536 .get_user_by_github_account(&query, None)
1537 .await?
1538 .into_iter()
1539 .collect(),
1540 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1541 };
1542 let users = users
1543 .into_iter()
1544 .filter(|user| user.id != session.user_id)
1545 .map(|user| proto::User {
1546 id: user.id.to_proto(),
1547 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1548 github_login: user.github_login,
1549 })
1550 .collect();
1551 response.send(proto::UsersResponse { users })?;
1552 Ok(())
1553}
1554
1555async fn request_contact(
1556 request: proto::RequestContact,
1557 response: Response<proto::RequestContact>,
1558 session: Session,
1559) -> Result<()> {
1560 let requester_id = session.user_id;
1561 let responder_id = UserId::from_proto(request.responder_id);
1562 if requester_id == responder_id {
1563 return Err(anyhow!("cannot add yourself as a contact"))?;
1564 }
1565
1566 session
1567 .db()
1568 .await
1569 .send_contact_request(requester_id, responder_id)
1570 .await?;
1571
1572 // Update outgoing contact requests of requester
1573 let mut update = proto::UpdateContacts::default();
1574 update.outgoing_requests.push(responder_id.to_proto());
1575 for connection_id in session
1576 .connection_pool()
1577 .await
1578 .user_connection_ids(requester_id)
1579 {
1580 session.peer.send(connection_id, update.clone())?;
1581 }
1582
1583 // Update incoming contact requests of responder
1584 let mut update = proto::UpdateContacts::default();
1585 update
1586 .incoming_requests
1587 .push(proto::IncomingContactRequest {
1588 requester_id: requester_id.to_proto(),
1589 should_notify: true,
1590 });
1591 for connection_id in session
1592 .connection_pool()
1593 .await
1594 .user_connection_ids(responder_id)
1595 {
1596 session.peer.send(connection_id, update.clone())?;
1597 }
1598
1599 response.send(proto::Ack {})?;
1600 Ok(())
1601}
1602
1603async fn respond_to_contact_request(
1604 request: proto::RespondToContactRequest,
1605 response: Response<proto::RespondToContactRequest>,
1606 session: Session,
1607) -> Result<()> {
1608 let responder_id = session.user_id;
1609 let requester_id = UserId::from_proto(request.requester_id);
1610 let db = session.db().await;
1611 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1612 db.dismiss_contact_notification(responder_id, requester_id)
1613 .await?;
1614 } else {
1615 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1616
1617 db.respond_to_contact_request(responder_id, requester_id, accept)
1618 .await?;
1619 let requester_busy = db.is_user_busy(requester_id).await?;
1620 let responder_busy = db.is_user_busy(responder_id).await?;
1621
1622 let pool = session.connection_pool().await;
1623 // Update responder with new contact
1624 let mut update = proto::UpdateContacts::default();
1625 if accept {
1626 update
1627 .contacts
1628 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1629 }
1630 update
1631 .remove_incoming_requests
1632 .push(requester_id.to_proto());
1633 for connection_id in pool.user_connection_ids(responder_id) {
1634 session.peer.send(connection_id, update.clone())?;
1635 }
1636
1637 // Update requester with new contact
1638 let mut update = proto::UpdateContacts::default();
1639 if accept {
1640 update
1641 .contacts
1642 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1643 }
1644 update
1645 .remove_outgoing_requests
1646 .push(responder_id.to_proto());
1647 for connection_id in pool.user_connection_ids(requester_id) {
1648 session.peer.send(connection_id, update.clone())?;
1649 }
1650 }
1651
1652 response.send(proto::Ack {})?;
1653 Ok(())
1654}
1655
1656async fn remove_contact(
1657 request: proto::RemoveContact,
1658 response: Response<proto::RemoveContact>,
1659 session: Session,
1660) -> Result<()> {
1661 let requester_id = session.user_id;
1662 let responder_id = UserId::from_proto(request.user_id);
1663 let db = session.db().await;
1664 db.remove_contact(requester_id, responder_id).await?;
1665
1666 let pool = session.connection_pool().await;
1667 // Update outgoing contact requests of requester
1668 let mut update = proto::UpdateContacts::default();
1669 update
1670 .remove_outgoing_requests
1671 .push(responder_id.to_proto());
1672 for connection_id in pool.user_connection_ids(requester_id) {
1673 session.peer.send(connection_id, update.clone())?;
1674 }
1675
1676 // Update incoming contact requests of responder
1677 let mut update = proto::UpdateContacts::default();
1678 update
1679 .remove_incoming_requests
1680 .push(requester_id.to_proto());
1681 for connection_id in pool.user_connection_ids(responder_id) {
1682 session.peer.send(connection_id, update.clone())?;
1683 }
1684
1685 response.send(proto::Ack {})?;
1686 Ok(())
1687}
1688
1689async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1690 let project_id = ProjectId::from_proto(request.project_id);
1691 let project_connection_ids = session
1692 .db()
1693 .await
1694 .project_connection_ids(project_id, session.connection_id)
1695 .await?;
1696 broadcast(
1697 session.connection_id,
1698 project_connection_ids.iter().copied(),
1699 |connection_id| {
1700 session
1701 .peer
1702 .forward_send(session.connection_id, connection_id, request.clone())
1703 },
1704 );
1705 Ok(())
1706}
1707
1708async fn get_private_user_info(
1709 _request: proto::GetPrivateUserInfo,
1710 response: Response<proto::GetPrivateUserInfo>,
1711 session: Session,
1712) -> Result<()> {
1713 let metrics_id = session
1714 .db()
1715 .await
1716 .get_user_metrics_id(session.user_id)
1717 .await?;
1718 let user = session
1719 .db()
1720 .await
1721 .get_user_by_id(session.user_id)
1722 .await?
1723 .ok_or_else(|| anyhow!("user not found"))?;
1724 response.send(proto::GetPrivateUserInfoResponse {
1725 metrics_id,
1726 staff: user.admin,
1727 })?;
1728 Ok(())
1729}
1730
1731fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1732 match message {
1733 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1734 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1735 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1736 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1737 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1738 code: frame.code.into(),
1739 reason: frame.reason,
1740 })),
1741 }
1742}
1743
1744fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1745 match message {
1746 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1747 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1748 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1749 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1750 AxumMessage::Close(frame) => {
1751 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1752 code: frame.code.into(),
1753 reason: frame.reason,
1754 }))
1755 }
1756 }
1757}
1758
1759fn build_initial_contacts_update(
1760 contacts: Vec<db::Contact>,
1761 pool: &ConnectionPool,
1762) -> proto::UpdateContacts {
1763 let mut update = proto::UpdateContacts::default();
1764
1765 for contact in contacts {
1766 match contact {
1767 db::Contact::Accepted {
1768 user_id,
1769 should_notify,
1770 busy,
1771 } => {
1772 update
1773 .contacts
1774 .push(contact_for_user(user_id, should_notify, busy, &pool));
1775 }
1776 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1777 db::Contact::Incoming {
1778 user_id,
1779 should_notify,
1780 } => update
1781 .incoming_requests
1782 .push(proto::IncomingContactRequest {
1783 requester_id: user_id.to_proto(),
1784 should_notify,
1785 }),
1786 }
1787 }
1788
1789 update
1790}
1791
1792fn contact_for_user(
1793 user_id: UserId,
1794 should_notify: bool,
1795 busy: bool,
1796 pool: &ConnectionPool,
1797) -> proto::Contact {
1798 proto::Contact {
1799 user_id: user_id.to_proto(),
1800 online: pool.is_user_online(user_id),
1801 busy,
1802 should_notify,
1803 }
1804}
1805
1806fn room_updated(room: &proto::Room, session: &Session) {
1807 for participant in &room.participants {
1808 session
1809 .peer
1810 .send(
1811 ConnectionId(participant.peer_id),
1812 proto::RoomUpdated {
1813 room: Some(room.clone()),
1814 },
1815 )
1816 .trace_err();
1817 }
1818}
1819
1820async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1821 let db = session.db().await;
1822 let contacts = db.get_contacts(user_id).await?;
1823 let busy = db.is_user_busy(user_id).await?;
1824
1825 let pool = session.connection_pool().await;
1826 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1827 for contact in contacts {
1828 if let db::Contact::Accepted {
1829 user_id: contact_user_id,
1830 ..
1831 } = contact
1832 {
1833 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1834 session
1835 .peer
1836 .send(
1837 contact_conn_id,
1838 proto::UpdateContacts {
1839 contacts: vec![updated_contact.clone()],
1840 remove_contacts: Default::default(),
1841 incoming_requests: Default::default(),
1842 remove_incoming_requests: Default::default(),
1843 outgoing_requests: Default::default(),
1844 remove_outgoing_requests: Default::default(),
1845 },
1846 )
1847 .trace_err();
1848 }
1849 }
1850 }
1851 Ok(())
1852}
1853
1854async fn leave_room_for_session(session: &Session) -> Result<()> {
1855 let mut contacts_to_update = HashSet::default();
1856
1857 let canceled_calls_to_user_ids;
1858 let live_kit_room;
1859 let delete_live_kit_room;
1860 {
1861 let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1862 contacts_to_update.insert(session.user_id);
1863
1864 for project in left_room.left_projects.values() {
1865 for connection_id in &project.connection_ids {
1866 if project.host_user_id == session.user_id {
1867 session
1868 .peer
1869 .send(
1870 *connection_id,
1871 proto::UnshareProject {
1872 project_id: project.id.to_proto(),
1873 },
1874 )
1875 .trace_err();
1876 } else {
1877 session
1878 .peer
1879 .send(
1880 *connection_id,
1881 proto::RemoveProjectCollaborator {
1882 project_id: project.id.to_proto(),
1883 peer_id: session.connection_id.0,
1884 },
1885 )
1886 .trace_err();
1887 }
1888 }
1889
1890 session
1891 .peer
1892 .send(
1893 session.connection_id,
1894 proto::UnshareProject {
1895 project_id: project.id.to_proto(),
1896 },
1897 )
1898 .trace_err();
1899 }
1900
1901 room_updated(&left_room.room, &session);
1902 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1903 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1904 delete_live_kit_room = left_room.room.participants.is_empty();
1905 }
1906
1907 {
1908 let pool = session.connection_pool().await;
1909 for canceled_user_id in canceled_calls_to_user_ids {
1910 for connection_id in pool.user_connection_ids(canceled_user_id) {
1911 session
1912 .peer
1913 .send(connection_id, proto::CallCanceled {})
1914 .trace_err();
1915 }
1916 contacts_to_update.insert(canceled_user_id);
1917 }
1918 }
1919
1920 for contact_user_id in contacts_to_update {
1921 update_user_contacts(contact_user_id, &session).await?;
1922 }
1923
1924 if let Some(live_kit) = session.live_kit_client.as_ref() {
1925 live_kit
1926 .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1927 .await
1928 .trace_err();
1929
1930 if delete_live_kit_room {
1931 live_kit.delete_room(live_kit_room).await.trace_err();
1932 }
1933 }
1934
1935 Ok(())
1936}
1937
/// Extension for turning a fallible value into an `Option`, reporting the
/// failure instead of propagating it.
pub trait ResultExt {
    /// The success type yielded by `trace_err`.
    type Ok;

    /// Returns the success value, or `None` if there was an error.
    /// Implementations decide how the error is reported.
    fn trace_err(self) -> Option<Self::Ok>;
}
1943
1944impl<T, E> ResultExt for Result<T, E>
1945where
1946 E: std::fmt::Debug,
1947{
1948 type Ok = T;
1949
1950 fn trace_err(self) -> Option<T> {
1951 match self {
1952 Ok(value) => Some(value),
1953 Err(error) => {
1954 tracing::error!("{:?}", error);
1955 None
1956 }
1957 }
1958 }
1959}